mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 08:22:40 +08:00
Compare commits
84 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3450a89880 | ||
|
|
a081a69bbe | ||
|
|
271d1d23d5 | ||
|
|
605aba1a3c | ||
|
|
be3c2b4c7c | ||
|
|
08eb32d7bd | ||
|
|
2b9cda15e4 | ||
|
|
f6055b290a | ||
|
|
ec665e05e4 | ||
|
|
2b6d7205ec | ||
|
|
41381a920c | ||
|
|
f1b3fc2254 | ||
|
|
a677ed307d | ||
|
|
0ab23ee972 | ||
|
|
43f56d39be | ||
|
|
a39caee5f5 | ||
|
|
2edfdf47c8 | ||
|
|
3819461db5 | ||
|
|
85654dd7dd | ||
|
|
619a70416b | ||
|
|
16d996fe70 | ||
|
|
1baeb6da19 | ||
|
|
1641d432dd | ||
|
|
1bf9862e47 | ||
|
|
602a394043 | ||
|
|
22a2415ca5 | ||
|
|
feb034352d | ||
|
|
a7c8942c78 | ||
|
|
95f2ac3811 | ||
|
|
91354295f2 | ||
|
|
c9c4ab5911 | ||
|
|
a26c5e40dd | ||
|
|
80f5c7bc44 | ||
|
|
4833b39c52 | ||
|
|
f478958943 | ||
|
|
0469ad46d6 | ||
|
|
5fe5deb9df | ||
|
|
ce83bc24bd | ||
|
|
dce729c8cb | ||
|
|
a9d17cd96f | ||
|
|
294bb3d4a1 | ||
|
|
b31b9261f2 | ||
|
|
2211f8d9e4 | ||
|
|
b9b7b00a7f | ||
|
|
843faf6103 | ||
|
|
4af5dad9a8 | ||
|
|
52437c9d18 | ||
|
|
c6cb4c8479 | ||
|
|
c3714ec251 | ||
|
|
dbe2f94af1 | ||
|
|
07fd5f8a9e | ||
|
|
9e64b4cd7f | ||
|
|
f08a7b9eb3 | ||
|
|
a6fa764e2a | ||
|
|
01676668f1 | ||
|
|
8e5e4f460d | ||
|
|
f907b8a84d | ||
|
|
a3a4285f90 | ||
|
|
0979163b79 | ||
|
|
248a25eaee | ||
|
|
f95b1fa68a | ||
|
|
d2b5d69051 | ||
|
|
3ca419b735 | ||
|
|
50e275a2f9 | ||
|
|
aeccf78957 | ||
|
|
cb3cef70e5 | ||
|
|
b9bd303bf8 | ||
|
|
57d4786a7f | ||
|
|
df031455b2 | ||
|
|
30059eff4f | ||
|
|
bc289b48c8 | ||
|
|
067d8b99b8 | ||
|
|
00a6a9c42d | ||
|
|
070425d446 | ||
|
|
7405883444 | ||
|
|
66959937ed | ||
|
|
e431efbcba | ||
|
|
ba00baa5a0 | ||
|
|
0fb5d4a164 | ||
|
|
1ac717b67f | ||
|
|
273cbd447e | ||
|
|
cee41567a2 | ||
|
|
1aae5eb1a6 | ||
|
|
28a4c81aff |
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -11,11 +9,12 @@ from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessa
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -26,7 +25,9 @@ class AgentChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
@@ -39,12 +40,6 @@ class MoviePilotAgent:
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
@@ -63,80 +58,37 @@ class MoviePilotAgent:
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
"""
|
||||
初始化LLM模型
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
"""
|
||||
初始化工具列表
|
||||
"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
@@ -161,14 +113,20 @@ class MoviePilotAgent:
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
"""
|
||||
初始化提示词模板
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
@@ -183,7 +141,9 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
@@ -210,10 +170,12 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
@@ -222,7 +184,7 @@ class MoviePilotAgent:
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
@@ -239,7 +201,7 @@ class MoviePilotAgent:
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
@@ -259,7 +221,9 @@ class MoviePilotAgent:
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
"""
|
||||
执行LangChain Agent
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
@@ -292,7 +256,9 @@ class MoviePilotAgent:
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
@@ -305,24 +271,32 @@ class MoviePilotAgent:
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
"""
|
||||
清理智能体资源
|
||||
"""
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
"""
|
||||
AI智能体管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
@@ -330,7 +304,9 @@ class AgentManager:
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
@@ -341,7 +317,6 @@ class AgentManager:
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -358,12 +333,14 @@ class AgentManager:
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
"""
|
||||
流式输出回调处理器
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
"""
|
||||
处理新的token
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
|
||||
@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
"""
|
||||
关闭记忆管理器
|
||||
"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
@@ -46,56 +52,83 @@ class ConversationMemoryManager:
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
def _get_memory_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算内存Key
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
@@ -148,7 +181,7 @@ class ConversationMemoryManager:
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
@@ -156,8 +189,10 @@ class ConversationMemoryManager:
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
@@ -177,7 +212,7 @@ class ConversationMemoryManager:
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
@@ -186,11 +221,12 @@ class ConversationMemoryManager:
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
@@ -205,8 +241,10 @@ class ConversationMemoryManager:
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取最近的消息
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
@@ -215,36 +253,41 @@ class ConversationMemoryManager:
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
保存记忆到内存
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
@@ -255,8 +298,22 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
"""
|
||||
清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
@@ -286,3 +343,5 @@ class ConversationMemoryManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
@@ -1,70 +1,72 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
You are an AI media assistant powered by MoviePilot, specialized in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
<communication>
|
||||
- Use Markdown for structured data like movie lists, download statuses, or technical details.
|
||||
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
|
||||
- ALWAYS use backticks for media titles (e.g., `Interstellar`), file paths, or specific parameters.
|
||||
- Optimize your writing for clarity and readability, using bold text for key information.
|
||||
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
|
||||
## Working Principles
|
||||
Important Notes:
|
||||
- User-Centric: Your tone should be helpful, professional, and media-savvy.
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
</communication>
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
<status_update_spec>
|
||||
Definition: Provide a brief progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "I'll search for the movie"), execute the corresponding tool call in the same turn.
|
||||
- Use natural tenses: "I've found...", "I'm checking...", "I will now add...".
|
||||
- Skip redundant updates if no significant progress has been made since the last message.
|
||||
</status_update_spec>
|
||||
|
||||
## Common Operation Workflows
|
||||
<summary_spec>
|
||||
At the end of your session/turn, provide a concise summary of your actions.
|
||||
- Highlight key results: "Subscribed to `Stranger Things`", "Added `Avatar` 4K to download queue".
|
||||
- Use bullet points for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the outcome for the user.
|
||||
</summary_spec>
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
<flow>
|
||||
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. Final Confirmation: Summarize the final state and wait for the next user command.
|
||||
</flow>
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
<tool_calling_strategy>
|
||||
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
<media_management_rules>
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
</media_management_rules>
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
@@ -1,13 +1,15 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
"""
|
||||
提示词管理器
|
||||
"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
@@ -17,22 +19,20 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
"""
|
||||
加载指定的提示词
|
||||
"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
@@ -46,73 +46,43 @@ class PromptManager:
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
|
||||
# 识别渠道
|
||||
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
base_prompt = base_prompt.replace(
|
||||
"{markdown_spec}",
|
||||
self._generate_formatting_instructions(caps)
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
@@ -6,7 +5,7 @@ from typing import Any, Optional
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -17,7 +16,9 @@ class ToolChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
@@ -25,7 +26,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -36,15 +36,14 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送和记忆工具调用前的信息
|
||||
"""
|
||||
异步运行工具
|
||||
"""
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
# 发送消息
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
@@ -56,15 +55,24 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
}
|
||||
)
|
||||
|
||||
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
@@ -73,15 +81,18 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -106,21 +117,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
"""
|
||||
设置消息属性
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
@@ -27,6 +25,7 @@ from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
|
||||
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
|
||||
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
@@ -46,13 +45,17 @@ from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
"""
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
@@ -61,6 +64,7 @@ class MoviePilotToolFactory:
|
||||
RecognizeMediaTool,
|
||||
ScrapeMetadataTool,
|
||||
QueryEpisodeScheduleTool,
|
||||
QueryMediaDetailTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
@@ -102,7 +106,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -125,7 +128,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -29,7 +29,8 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
def _get_all_torrents(self, download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
@staticmethod
|
||||
def _get_all_torrents(download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
|
||||
120
app/agent/tools/impl/query_media_detail.py
Normal file
120
app/agent/tools/impl/query_media_detail.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""查询媒体详情工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)")
|
||||
media_type: str = Field(..., description="Media type: 'movie' or 'tv'")
|
||||
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影'→'movie', '电视剧'→'tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
|
||||
async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
|
||||
mtype = None
|
||||
if media_type:
|
||||
if media_type.lower() == 'movie':
|
||||
mtype = MediaType.MOVIE
|
||||
elif media_type.lower() == 'tv':
|
||||
mtype = MediaType.TV
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 精简 genres - 只保留名称
|
||||
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
|
||||
|
||||
# 精简 directors - 只保留姓名和职位
|
||||
directors = [
|
||||
{
|
||||
"name": d.get("name"),
|
||||
"job": d.get("job")
|
||||
}
|
||||
for d in (mediainfo.directors or [])
|
||||
if d.get("name")
|
||||
]
|
||||
|
||||
# 精简 actors - 只保留姓名和角色
|
||||
actors = [
|
||||
{
|
||||
"name": a.get("name"),
|
||||
"character": a.get("character")
|
||||
}
|
||||
for a in (mediainfo.actors or [])
|
||||
if a.get("name")
|
||||
]
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"overview": mediainfo.overview,
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
"actors": actors
|
||||
}
|
||||
|
||||
# 如果是电视剧,添加电视剧特有信息
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 精简 season_info - 只保留基础摘要
|
||||
season_info = [
|
||||
{
|
||||
"season_number": s.get("season_number"),
|
||||
"name": s.get("name"),
|
||||
"episode_count": s.get("episode_count"),
|
||||
"air_date": s.get("air_date")
|
||||
}
|
||||
for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
]
|
||||
|
||||
result.update({
|
||||
"number_of_seasons": mediainfo.number_of_seasons,
|
||||
"number_of_episodes": mediainfo.number_of_episodes,
|
||||
"first_air_date": mediainfo.first_air_date,
|
||||
"last_air_date": mediainfo.last_air_date,
|
||||
"season_info": season_info
|
||||
})
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询媒体详情失败: {str(e)}"
|
||||
logger.error(f"查询媒体详情失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id
|
||||
}, ensure_ascii=False)
|
||||
@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
@@ -99,7 +100,7 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"size": StringUtils.format_size(t.torrent_info.size),
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
"""
|
||||
工具定义
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
@@ -21,7 +18,9 @@ class ToolDefinition:
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
"""
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
@@ -34,11 +33,12 @@ class MoviePilotToolsManager:
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
"""
|
||||
加载所有MoviePilot工具
|
||||
"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
@@ -48,7 +48,6 @@ class MoviePilotToolsManager:
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -116,7 +115,7 @@ class MoviePilotToolsManager:
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
@@ -124,7 +123,7 @@ class MoviePilotToolsManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
@@ -132,10 +131,10 @@ class MoviePilotToolsManager:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
@@ -144,27 +143,30 @@ class MoviePilotToolsManager:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = value
|
||||
normalized[key] = None
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = value
|
||||
elif field_type == "boolean" and isinstance(value, str):
|
||||
# 转换字符串为布尔值
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
normalized[key] = None
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
elif isinstance(value, (int, float)):
|
||||
normalized[key] = value != 0
|
||||
else:
|
||||
normalized[key] = True
|
||||
else:
|
||||
# 其他类型保持原样
|
||||
normalized[key] = value
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
@@ -189,7 +191,7 @@ class MoviePilotToolsManager:
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
@@ -199,7 +201,11 @@ class MoviePilotToolsManager:
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
try:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
formated_result = str(result)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
@@ -263,3 +269,6 @@ class MoviePilotToolsManager:
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -32,11 +32,11 @@ def login_access_token(
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent.tools.manager import moviepilot_tool_manager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
@@ -25,18 +21,10 @@ MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
"""
|
||||
创建 JSON-RPC 成功响应
|
||||
"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -45,8 +33,11 @@ def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> D
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
|
||||
str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 错误响应
|
||||
"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -60,8 +51,6 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message:
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
@@ -146,7 +135,9 @@ async def mcp_jsonrpc(
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
"""
|
||||
处理初始化请求
|
||||
"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
@@ -161,7 +152,7 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
@@ -180,9 +171,10 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
@@ -200,18 +192,18 @@ async def handle_tools_list() -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
"""
|
||||
处理工具调用请求
|
||||
"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
@@ -243,8 +235,6 @@ async def delete_mcp_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
@@ -257,9 +247,8 @@ async def list_tools(
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
@@ -289,11 +278,8 @@ async def call_tool(
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
@@ -319,9 +305,8 @@ async def get_tool_info(
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
@@ -352,9 +337,8 @@ async def get_tool_schema(
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
|
||||
@@ -24,6 +24,75 @@ from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建凭证列表
|
||||
|
||||
:param passkeys: PassKey 列表
|
||||
:return: 凭证字典列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
|
||||
def _extract_and_standardize_credential_id(credential: dict) -> str:
|
||||
"""
|
||||
从凭证中提取并标准化 credential_id
|
||||
|
||||
:param credential: 凭证字典
|
||||
:return: 标准化后的 credential_id
|
||||
:raises ValueError: 如果凭证无效
|
||||
"""
|
||||
credential_id_raw = credential.get('id') or credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise ValueError("无效的凭证")
|
||||
return PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
|
||||
def _verify_passkey_and_update(
|
||||
credential: dict,
|
||||
challenge: str,
|
||||
passkey: PassKey
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
验证 PassKey 并更新使用时间和签名计数
|
||||
|
||||
:param credential: 凭证字典
|
||||
:param challenge: 挑战值
|
||||
:param passkey: PassKey 对象
|
||||
:return: (验证是否成功, 新的签名计数)
|
||||
"""
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if success:
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
return success, new_sign_count
|
||||
|
||||
|
||||
async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否有 PassKey
|
||||
|
||||
:param db: 数据库会话
|
||||
:param user_id: 用户 ID
|
||||
:return: 是否有 PassKey
|
||||
"""
|
||||
return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id))
|
||||
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
@@ -55,7 +124,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) ->
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
has_passkey = await _check_user_has_passkey(db, user.id)
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
@@ -93,7 +162,7 @@ async def otp_disable(
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
has_passkey = await _check_user_has_passkey(db, current_user.id)
|
||||
if has_passkey:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
@@ -147,13 +216,7 @@ def passkey_register_start(
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
] if existing_passkeys else None
|
||||
existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
@@ -233,26 +296,15 @@ def passkey_authenticate_start(
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
if not user:
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None
|
||||
|
||||
if not user or not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="用户不存在"
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="该用户未注册通行密钥"
|
||||
)
|
||||
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
existing_credentials = _build_credential_list(existing_passkeys)
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
@@ -270,7 +322,7 @@ def passkey_authenticate_start(
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成认证选项失败: {str(e)}"
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
|
||||
@@ -280,37 +332,28 @@ def passkey_authenticate_finish(
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise HTTPException(status_code=400, detail="无效的凭证")
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey认证失败,提供的凭证无效: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
# 查找PassKey并获取用户
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None
|
||||
if not passkey or not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 获取用户
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已禁用")
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
@@ -339,7 +382,7 @@ def passkey_authenticate_finish(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
@@ -413,16 +456,12 @@ def passkey_verify_mfa(
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="无效的凭证"
|
||||
)
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey二次验证失败,提供的凭证无效: {e}")
|
||||
return schemas.Response(success=False, message="验证失败")
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
@@ -432,12 +471,11 @@ def passkey_verify_mfa(
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -446,9 +484,6 @@ def passkey_verify_mfa(
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
@@ -459,5 +494,5 @@ def passkey_verify_mfa(
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
message="验证失败"
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
|
||||
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
@@ -167,3 +175,87 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
|
||||
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
|
||||
async def recommend_search_results(
|
||||
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
|
||||
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
|
||||
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
AI推荐资源 - 轮询接口
|
||||
前端轮询此接口,发送筛选后的索引(如果有筛选)
|
||||
后端根据请求变化自动取消旧任务并启动新任务
|
||||
|
||||
参数:
|
||||
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
|
||||
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
|
||||
- force: 强制重新推荐(清除旧结果并重新启动)
|
||||
|
||||
返回数据结构:
|
||||
{
|
||||
"success": bool,
|
||||
"message": string, // 错误信息(仅在错误时存在)
|
||||
"data": {
|
||||
"status": string, // 状态: disabled | idle | running | completed | error
|
||||
"results": array // 推荐结果(仅status=completed时存在)
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 从缓存获取上次搜索结果
|
||||
results = await SearchChain().async_last_search_results() or []
|
||||
if not results:
|
||||
return schemas.Response(success=False, message="没有可用的搜索结果", data={
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
logger.info("收到新推荐请求,清除旧结果并启动新任务")
|
||||
recommend_chain.cancel_ai_recommend()
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 直接返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=current_status)
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
# 如果是空闲状态,启动新任务
|
||||
if status_data["status"] == "idle":
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 立即返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if status_data.get("status") == "error":
|
||||
error_msg = status_data.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=status_data)
|
||||
|
||||
# 返回当前状态
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
@@ -130,28 +130,52 @@ async def cache_img(
|
||||
def get_global_setting(token: str):
|
||||
"""
|
||||
查询非敏感系统设置(默认鉴权)
|
||||
仅包含登录前UI初始化必需的字段
|
||||
"""
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# 白名单模式,仅包含前端业务逻辑必需的字段
|
||||
# 白名单模式,仅包含登录前UI初始化必需的字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE"
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update({
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
|
||||
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
包含业务功能相关的配置和用户权限信息
|
||||
"""
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED"
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
@@ -4,6 +4,7 @@ import pickle
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
@@ -849,6 +850,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
@@ -932,6 +935,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
|
||||
318
app/chain/ai_recommend.py
Normal file
318
app/chain/ai_recommend.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
"""
|
||||
|
||||
# 缓存文件名
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
# AI推荐状态
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None # 当前请求的哈希值
|
||||
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表)
|
||||
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
:return: 状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
# 尝试从数据库加载缓存
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
# 只要有结果,始终返回completed状态和数据
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
# 计算当前请求的hash
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 检查请求是否变化
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
# 如果请求变化了(筛选条件改变),返回idle状态
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
# 请求未变化,返回当前实际状态
|
||||
return self._build_status()
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
|
||||
"""
|
||||
AI推荐
|
||||
:param items: 候选资源列表(JSON字符串格式)
|
||||
:param preference: 用户偏好(可选)
|
||||
:return: AI返回的推荐结果
|
||||
"""
|
||||
# 设置运行状态
|
||||
self._ai_recommend_running = True
|
||||
try:
|
||||
# 导入LLMHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
# 获取LLM实例
|
||||
llm = LLMHelper.get_llm()
|
||||
|
||||
# 构建提示词
|
||||
user_preference = (
|
||||
preference
|
||||
or settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
# 添加指令
|
||||
instruction = """
|
||||
Task: Select the best matching items from the list based on user preferences.
|
||||
|
||||
Each item contains:
|
||||
- index: Item number
|
||||
- title: Full torrent title
|
||||
- size: File size
|
||||
- seeders: Number of seeders
|
||||
|
||||
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
|
||||
"""
|
||||
message = (
|
||||
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
|
||||
+ "\n".join(items)
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = await llm.ainvoke(message)
|
||||
return response.content
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"AI推荐配置错误: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# 清除运行状态
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
# 防护检查:确保AI推荐功能已启用
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
# 计算新请求的哈希值
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 如果请求变化了,取消旧任务
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
|
||||
# 更新请求哈希值
|
||||
self._current_request_hash = new_request_hash
|
||||
|
||||
# 重置状态
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
# 启动新任务
|
||||
async def run_recommend():
|
||||
# 获取当前任务对象,用于在finally中比对
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
# 准备数据
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
# 调用AI推荐
|
||||
ai_response = await self.async_ai_recommend(items)
|
||||
|
||||
# 解析AI返回的索引
|
||||
try:
|
||||
# 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组)
|
||||
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(ai_response)
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
# 映射回原始索引
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= filtered_indices[valid_indices[i]] < len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
# 只返回索引列表,不返回完整数据
|
||||
self._ai_recommend_result = original_indices
|
||||
|
||||
# 保存到数据库
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
|
||||
# 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
# 创建并启动任务
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
@@ -292,6 +292,10 @@ class DownloadChain(ChainBase):
|
||||
|
||||
# 登记下载记录
|
||||
downloadhis = DownloadHistoryOper()
|
||||
# 获取应用的识别词(如果有)
|
||||
custom_words_str = None
|
||||
if hasattr(_meta, 'apply_words') and _meta.apply_words:
|
||||
custom_words_str = '\n'.join(_meta.apply_words)
|
||||
downloadhis.add(
|
||||
path=download_path.as_posix(),
|
||||
type=_media.type.value,
|
||||
@@ -315,6 +319,7 @@ class DownloadChain(ChainBase):
|
||||
date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
media_category=_media.category,
|
||||
episode_group=_media.episode_group,
|
||||
custom_words=custom_words_str,
|
||||
note={"source": source}
|
||||
)
|
||||
|
||||
|
||||
@@ -315,21 +315,6 @@ class MediaChain(ChainBase):
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in StorageChain().list_files(fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
@eventmanager.register(EventType.MetadataScrape)
|
||||
def scrape_metadata_event(self, event: Event):
|
||||
"""
|
||||
@@ -370,7 +355,7 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
if file_list:
|
||||
# 如果是BDMV原盘目录,只对根目录进行刮削,不处理子目录
|
||||
if self.is_bluray_folder(fileitem):
|
||||
if storagechain.is_bluray_folder(fileitem):
|
||||
logger.info(f"检测到BDMV原盘目录,只对根目录进行刮削:{fileitem.path}")
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
@@ -563,10 +548,23 @@ class MediaChain(ChainBase):
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 电影目录
|
||||
if recursive:
|
||||
# 处理文件
|
||||
if self.is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
is_bluray_folder = storagechain.contains_bluray_subdirectories(files)
|
||||
if recursive and not is_bluray_folder:
|
||||
# 处理非原盘目录内的文件
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
if is_bluray_folder:
|
||||
# 检查电影NFO开关
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
|
||||
@@ -581,20 +579,6 @@ class MediaChain(ChainBase):
|
||||
logger.info(f"已存在nfo文件:{nfo_path}")
|
||||
else:
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
# 图片
|
||||
image_dict = self.metadata_img(mediainfo=mediainfo)
|
||||
if image_dict:
|
||||
@@ -618,7 +602,7 @@ class MediaChain(ChainBase):
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 流式下载图片并直接保存
|
||||
@@ -681,7 +665,11 @@ class MediaChain(ChainBase):
|
||||
if recursive:
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir" and not file.name.lower().startswith("season"):
|
||||
if (
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and not file.name.lower().startswith("season")
|
||||
):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
@@ -691,11 +679,19 @@ class MediaChain(ChainBase):
|
||||
overwrite=overwrite)
|
||||
# 生成目录的nfo和图片
|
||||
if init_folder:
|
||||
# TODO 目前的刮削是假定电视剧目录结构符合:/剧集根目录/季目录/剧集文件
|
||||
# 其中季目录应符合`Season 数字`等明确的季命名,不能用季标题
|
||||
# 例如:/Torchwood (2006)/Miracle Day/Torchwood (2006) S04E01.mkv
|
||||
# 当刮削到`Miracle Day`目录时,会误判其为剧集根目录
|
||||
# 识别文件夹名称
|
||||
season_meta = MetaInfo(filepath.name)
|
||||
# 当前文件夹为Specials或者SPs时,设置为S0
|
||||
if filepath.name in settings.RENAME_FORMAT_S0_NAMES:
|
||||
season_meta.begin_season = 0
|
||||
elif season_meta.name and season_meta.begin_season is not None:
|
||||
# 当前目录含有非季目录的名称,但却有季信息(通常是被辅助识别词指定了)
|
||||
# 这种情况应该是剧集根目录,不能按季目录刮削,否则会导致`season_poster`的路径错误 详见issue#5373
|
||||
season_meta.begin_season = None
|
||||
if season_meta.begin_season is not None:
|
||||
# 检查季NFO开关
|
||||
if scraping_switchs.get('season_nfo', True):
|
||||
@@ -765,7 +761,8 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
logger.info(f"季图片刮削已关闭,跳过:{image_name}")
|
||||
# 判断当前目录是不是剧集根目录
|
||||
if not season_meta.season:
|
||||
elif season_meta.name:
|
||||
# 不含季信息(包括特别季)但含有名称的,可以认为是剧集根目录
|
||||
# 检查电视剧NFO开关
|
||||
if scraping_switchs.get('tv_nfo', True):
|
||||
# 是否已存在
|
||||
|
||||
@@ -195,10 +195,14 @@ class MessageChain(ChainBase):
|
||||
if text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
if not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
@@ -370,12 +374,13 @@ class MessageChain(ChainBase):
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
if _current_page == 0:
|
||||
# 第一页
|
||||
@@ -422,12 +427,13 @@ class MessageChain(ChainBase):
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
|
||||
@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
|
||||
@@ -44,6 +44,7 @@ class SiteChain(ChainBase):
|
||||
"star-space.net": self.__indexphp_test,
|
||||
"yemapt.org": self.__yema_test,
|
||||
"hddolby.com": self.__hddolby_test,
|
||||
"rousi.pro": self.__rousi_test,
|
||||
}
|
||||
|
||||
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
|
||||
@@ -249,6 +250,32 @@ class SiteChain(ChainBase):
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __rousi_test(site: Site) -> Tuple[bool, str]:
|
||||
"""
|
||||
判断站点是否已经登陆:rousi
|
||||
"""
|
||||
url = f"https://{StringUtils.get_url_domain(site.url)}/api/v1/profile"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {site.apikey}",
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if site.proxy else None,
|
||||
timeout=site.timeout or 15
|
||||
).get_res(url=url)
|
||||
if res is None:
|
||||
return False, "无法打开网站!"
|
||||
if res.status_code == 200:
|
||||
user_info = res.json()
|
||||
if user_info and user_info.get("code") == 0:
|
||||
return True, "连接成功"
|
||||
return False, "APIKEY已过期"
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __parse_favicon(url: str, cookie: str, ua: str) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
@@ -462,20 +489,18 @@ class SiteChain(ChainBase):
|
||||
logger.warn(f"站点 {domain} 索引器不存在!")
|
||||
return
|
||||
# 查询站点图标
|
||||
site_icon = siteoper.get_icon_by_domain(domain)
|
||||
if not site_icon or not site_icon.base64:
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
|
||||
@eventmanager.register(EventType.SiteUpdated)
|
||||
def clear_site_data(self, event: Event):
|
||||
|
||||
@@ -133,22 +133,29 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def is_bluray_folder(self, fileitem: Optional[schemas.FileItem]) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
return self.contains_bluray_subdirectories(self.list_files(fileitem))
|
||||
|
||||
@staticmethod
|
||||
def contains_bluray_subdirectories(fileitems: Optional[List[schemas.FileItem]]) -> bool:
|
||||
"""
|
||||
判断是否包含蓝光必备的文件夹
|
||||
"""
|
||||
required_files = ("BDMV", "CERTIFICATE")
|
||||
for item in fileitems or []:
|
||||
if item.type == "dir" and item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
|
||||
def __is_bluray_dir(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
_dir_files = self.list_files(fileitem=_fileitem, recursion=False)
|
||||
if _dir_files:
|
||||
for _f in _dir_files:
|
||||
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
@@ -156,7 +163,7 @@ class StorageChain(ChainBase):
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
# 本身是目录
|
||||
if __is_bluray_dir(fileitem):
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.warn(f"正在删除蓝光原盘目录:【{fileitem.storage}】{fileitem.path}")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
|
||||
@@ -1635,7 +1635,7 @@ class SubscribeChain(ChainBase):
|
||||
info = schemas.SubscribeEpisodeInfo()
|
||||
info.title = episode.name
|
||||
info.description = episode.overview
|
||||
info.backdrop = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/w500${episode.still_path}"
|
||||
info.backdrop = settings.TMDB_IMAGE_URL(episode.still_path, "w500")
|
||||
episodes[episode.episode_number] = info
|
||||
elif subscribe.type == MediaType.TV.value:
|
||||
# 根据开始结束集计算集信息
|
||||
|
||||
@@ -376,7 +376,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
self._transfer_interval = 15
|
||||
# 事件管理器
|
||||
self.jobview = JobManager()
|
||||
# 车移成功的文件清单
|
||||
# 转移成功的文件清单
|
||||
self._success_target_files: Dict[str, List[str]] = {}
|
||||
# 启动整理任务
|
||||
self.__init()
|
||||
@@ -560,8 +560,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
processed_num = 0
|
||||
# 失败数量
|
||||
fail_num = 0
|
||||
# 已完成文件
|
||||
finished_files = []
|
||||
|
||||
progress = ProgressHelper(ProgressKey.FileTransfer)
|
||||
|
||||
@@ -594,10 +592,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=processed_num / total_num * 100,
|
||||
text=__process_msg,
|
||||
data={
|
||||
"current": Path(fileitem.path).as_posix(),
|
||||
"finished": finished_files
|
||||
})
|
||||
data={})
|
||||
# 整理
|
||||
state, err_msg = self.__handle_transfer(task=task, callback=item.callback)
|
||||
if not state:
|
||||
@@ -605,7 +600,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
fail_num += 1
|
||||
# 更新进度
|
||||
processed_num += 1
|
||||
finished_files.append(Path(fileitem.path).as_posix())
|
||||
__process_msg = f"{fileitem.name} 整理完成"
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=(processed_num / total_num) * 100,
|
||||
@@ -873,7 +867,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
state, errmsg = self.do_transfer(
|
||||
fileitem=FileItem(
|
||||
storage="local",
|
||||
path=file_path.as_posix(),
|
||||
path=file_path.as_posix() + ("/" if file_path.is_dir() else ""),
|
||||
type="dir" if not file_path.is_file() else "file",
|
||||
name=file_path.name,
|
||||
size=file_path.stat().st_size,
|
||||
@@ -908,16 +902,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
storagechain = StorageChain()
|
||||
|
||||
def __contains_bluray_sub(_fileitems: List[FileItem]) -> bool:
|
||||
"""
|
||||
判断是否包含蓝光子目录
|
||||
"""
|
||||
if _fileitems:
|
||||
for sub in _fileitems:
|
||||
if sub.type == "dir" and sub.name in ["BDMV", "CERTIFICATE"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __is_bluray_sub(_path: str) -> bool:
|
||||
"""
|
||||
判断是否蓝光原盘目录内的子目录或文件
|
||||
@@ -933,9 +917,12 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
return storagechain.get_file_item(storage=_storage, path=p.parent)
|
||||
return None
|
||||
|
||||
if not storagechain.get_item(fileitem):
|
||||
latest_fileitem = storagechain.get_item(fileitem)
|
||||
if not latest_fileitem:
|
||||
logger.warn(f"目录或文件不存在:{fileitem.path}")
|
||||
return []
|
||||
# 确保从历史记录重新整理时 能获得最新的源文件大小、修改日期等
|
||||
fileitem = latest_fileitem
|
||||
|
||||
# 蓝光原盘子目录或文件
|
||||
if __is_bluray_sub(fileitem.path):
|
||||
@@ -949,7 +936,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
|
||||
# 蓝光原盘根目录
|
||||
sub_items = storagechain.list_files(fileitem) or []
|
||||
if __contains_bluray_sub(sub_items):
|
||||
if storagechain.contains_bluray_subdirectories(sub_items):
|
||||
return [(fileitem, True)]
|
||||
|
||||
# 需要整理的文件项列表
|
||||
@@ -1093,9 +1080,26 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
err_msgs.append(f"{file_item.name} 已整理过")
|
||||
continue
|
||||
|
||||
# 提前获取下载历史,以便获取自定义识别词
|
||||
download_history = None
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
# 获取自定义识别词
|
||||
custom_words_list = None
|
||||
if download_history and download_history.custom_words:
|
||||
custom_words_list = download_history.custom_words.split('\n')
|
||||
|
||||
if not meta:
|
||||
# 文件元数据
|
||||
file_meta = MetaInfoPath(file_path)
|
||||
# 文件元数据(传入自定义识别词)
|
||||
file_meta = MetaInfoPath(file_path, custom_words=custom_words_list)
|
||||
else:
|
||||
file_meta = meta
|
||||
|
||||
@@ -1121,18 +1125,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
if end_ep is not None:
|
||||
file_meta.end_episode = end_ep
|
||||
|
||||
# 根据父路径获取下载历史
|
||||
download_history = None
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
# 获取下载Hash
|
||||
if download_history and (not downloader or not download_hash):
|
||||
downloader = download_history.downloader
|
||||
|
||||
@@ -278,7 +278,7 @@ class ConfigModel(BaseModel):
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
# 最大搜索名称数量
|
||||
MAX_SEARCH_NAME_LIMIT: int = 2
|
||||
MAX_SEARCH_NAME_LIMIT: int = 3
|
||||
|
||||
# ==================== 下载配置 ====================
|
||||
# 种子标签
|
||||
@@ -439,6 +439,12 @@ class ConfigModel(BaseModel):
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
# 是否启用AI推荐
|
||||
AI_RECOMMEND_ENABLED: bool = False
|
||||
# AI推荐用户偏好
|
||||
AI_RECOMMEND_USER_PREFERENCE: str = ""
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -843,6 +849,18 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
rename_format = re.sub(r'/+', '/', rename_format)
|
||||
return rename_format.strip("/")
|
||||
|
||||
def TMDB_IMAGE_URL(self, file_path: str, file_size: str = "original") -> str:
|
||||
"""
|
||||
获取TMDB图片网址
|
||||
|
||||
:param file_path: TMDB API返回的xxx_path
|
||||
:param file_size: 图片大小,例如:'original', 'w500' 等
|
||||
:return: 图片的完整URL
|
||||
"""
|
||||
return (
|
||||
f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
|
||||
)
|
||||
|
||||
|
||||
# 实例化配置
|
||||
settings = Settings()
|
||||
|
||||
@@ -479,11 +479,11 @@ class MediaInfo:
|
||||
self.episode_groups = info.pop("episode_groups").get("results") or []
|
||||
|
||||
# 海报
|
||||
if info.get('poster_path'):
|
||||
self.poster_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('poster_path')}"
|
||||
if path := info.get('poster_path'):
|
||||
self.poster_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 背景
|
||||
if info.get('backdrop_path'):
|
||||
self.backdrop_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('backdrop_path')}"
|
||||
if path := info.get('backdrop_path'):
|
||||
self.backdrop_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 导演和演员
|
||||
self.directors, self.actors = __directors_actors(info)
|
||||
# 别名和译名
|
||||
|
||||
@@ -62,21 +62,24 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
|
||||
return meta
|
||||
|
||||
|
||||
def MetaInfoPath(path: Path) -> MetaBase:
|
||||
def MetaInfoPath(path: Path, custom_words: List[str] = None) -> MetaBase:
|
||||
"""
|
||||
根据路径识别元数据
|
||||
:param path: 路径
|
||||
:param custom_words: 自定义识别词列表
|
||||
"""
|
||||
# 文件元数据,不包含后缀
|
||||
file_meta = MetaInfo(title=path.name)
|
||||
file_meta = MetaInfo(title=path.name, custom_words=custom_words)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
dir_meta = MetaInfo(title=path.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
# 上上级目录元数据
|
||||
root_meta = MetaInfo(title=path.parent.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
root_meta = MetaInfo(title=path.parent.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
return file_meta
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, AP
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -24,7 +25,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
# OAuth2PasswordBearer 用于 JWT Token 认证
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
oauth2_scheme_manual_error = OAuth2PasswordBearer(
|
||||
auto_error=False, # 禁用自动错误处理,用以支持API令牌鉴权
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
)
|
||||
|
||||
@@ -41,6 +43,58 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
|
||||
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用请求头
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_header or key_query # 首选请求头
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl=600)
|
||||
def __create_superuser_token_payload() -> schemas.TokenPayload:
|
||||
"""
|
||||
创建管理员用户的TokenPayload
|
||||
|
||||
:return: 管理员TokenPayload
|
||||
"""
|
||||
# 延迟导入
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=no-name-in-module
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
|
||||
user = UserOper().get_by_name(settings.SUPERUSER)
|
||||
if not user or not user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户权限不足",
|
||||
)
|
||||
return schemas.TokenPayload(
|
||||
sub=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
level=SitesHelper().auth_level,
|
||||
purpose="authentication",
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
userid: Union[str, Any],
|
||||
username: str,
|
||||
@@ -176,23 +230,43 @@ def __verify_token(token: str, purpose: Optional[str] = "authentication") -> sch
|
||||
def verify_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
token: Annotated[str, Security(oauth2_scheme)]
|
||||
jwt_token: Annotated[str | None, Security(oauth2_scheme_manual_error)],
|
||||
api_key: Annotated[str | None, Security(__get_api_key)],
|
||||
api_token: Annotated[str | None, Security(__get_api_token)],
|
||||
) -> schemas.TokenPayload:
|
||||
"""
|
||||
验证 JWT 令牌并自动处理 resource_token 写入
|
||||
|
||||
如果缺少JWT令牌再尝试用API令牌鉴权
|
||||
|
||||
:param request: 请求对象,用于访问 Cookie 和请求信息
|
||||
:param response: 响应对象,用于设置 Cookie
|
||||
:param token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param jwt_token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param api_key: 从 查询参数`apikey` 或 请求头`X-API-KEY` 获取 API Token
|
||||
:param api_token: 从 查询参数`token` 获取 API Token
|
||||
:return: 解析后的 TokenPayload
|
||||
:raises HTTPException: 如果令牌无效或用途不匹配
|
||||
"""
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=token, purpose="authentication")
|
||||
if jwt_token:
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=jwt_token, purpose="authentication")
|
||||
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
|
||||
return payload
|
||||
return payload
|
||||
elif api_key:
|
||||
verify_apikey(api_key)
|
||||
return __create_superuser_token_payload()
|
||||
elif api_token:
|
||||
verify_apitoken(api_token)
|
||||
return __create_superuser_token_payload()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def verify_resource_token(
|
||||
@@ -208,31 +282,7 @@ def verify_resource_token(
|
||||
return __verify_token(token=resource_token, purpose="resource")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用 URL 参数
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_query or key_header
|
||||
|
||||
|
||||
def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
def __verify_key(key: str | None, expected_key: str, key_type: str) -> str:
|
||||
"""
|
||||
通用的 API Key 或 Token 验证函数
|
||||
:param key: 从请求中获取的 API Key 或 Token
|
||||
@@ -241,7 +291,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
:return: 返回校验通过的 API Key 或 Token
|
||||
:raises HTTPException: 如果校验不通过,抛出 401 错误
|
||||
"""
|
||||
if key != expected_key:
|
||||
if not key or key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{key_type} 校验不通过"
|
||||
@@ -249,7 +299,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
def verify_apitoken(token: Annotated[str | None, Security(__get_api_token)]) -> str:
|
||||
"""
|
||||
使用 API Token 进行身份认证
|
||||
:param token: API Token,从 URL 查询参数中获取 token=xxx
|
||||
@@ -258,7 +308,7 @@ def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
return __verify_key(token, settings.API_TOKEN, "token")
|
||||
|
||||
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
def verify_apikey(apikey: Annotated[str | None, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx,或请求头中获取 X-API-KEY=xxx
|
||||
|
||||
@@ -55,6 +55,8 @@ class DownloadHistory(Base):
|
||||
media_category = Column(String)
|
||||
# 剧集组
|
||||
episode_group = Column(String)
|
||||
# 自定义识别词(用于整理时应用)
|
||||
custom_words = Column(String)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from app.db import DbOper
|
||||
@@ -17,6 +19,8 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
super().__init__()
|
||||
self.__SYSTEMCONF = {}
|
||||
self._rlock = threading.RLock()
|
||||
self._alock = asyncio.Lock()
|
||||
for item in SystemConfig.list(self._db):
|
||||
self.__SYSTEMCONF[item.key] = item.value
|
||||
|
||||
@@ -29,23 +33,24 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
with self._rlock:
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
|
||||
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
|
||||
"""
|
||||
@@ -56,22 +61,32 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
async with self._alock:
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
# 确定是否需要更新数据库
|
||||
needs_db_update = False
|
||||
if conf:
|
||||
if conf.value != value:
|
||||
needs_db_update = True
|
||||
else: # 记录不存在,总是需要创建/更新
|
||||
needs_db_update = True
|
||||
if not needs_db_update:
|
||||
# 即使数据库值相同,也要确保缓存同步
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return None
|
||||
# 执行数据库更新
|
||||
if conf:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
await conf.async_update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
await conf.async_delete(self._db, conf.id)
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
# 数据库更新成功后,再更新缓存
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return True
|
||||
|
||||
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
|
||||
@@ -82,15 +97,17 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
key = key.value
|
||||
if not key:
|
||||
return self.all()
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
|
||||
def all(self):
|
||||
"""
|
||||
获取所有系统设置
|
||||
"""
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
|
||||
def delete(self, key: Union[str, SystemConfigKey]) -> bool:
|
||||
"""
|
||||
@@ -98,10 +115,11 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
with self._rlock:
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
|
||||
@@ -142,19 +142,22 @@ class DirectoryHelper:
|
||||
# 计算重命名中的文件夹层数
|
||||
rename_list = rename_format.split("/")
|
||||
rename_format_level = len(rename_list) - 1
|
||||
# 查找标题参数所在层
|
||||
for level, name in enumerate(rename_list):
|
||||
# 反向查找标题参数所在层
|
||||
for level, name in enumerate(reversed(rename_list)):
|
||||
if level == 0:
|
||||
# 跳过文件名的标题参数
|
||||
continue
|
||||
matchs = JINJA2_VAR_PATTERN.findall(name)
|
||||
if not matchs:
|
||||
continue
|
||||
# 处理特例,有的人重命名的第一层是年份、分辨率
|
||||
if any("title" in m for m in matchs):
|
||||
# 找出含标题的这一层作为媒体根目录
|
||||
rename_format_level -= level
|
||||
# 找出最后一层含有标题参数的目录作为媒体根目录
|
||||
rename_format_level = level
|
||||
break
|
||||
else:
|
||||
# 假定第一层目录是媒体根目录
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题目录")
|
||||
if rename_format_level > len(rename_path.parents):
|
||||
# 通常因为路径以/结尾,被Path规范化删除了
|
||||
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")
|
||||
|
||||
@@ -1,12 +1,76 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param callbacks: 回调处理器列表
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
|
||||
@@ -90,6 +90,79 @@ class PassKeyHelper:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@staticmethod
|
||||
def _base64_encode_urlsafe(data: bytes) -> str:
|
||||
"""
|
||||
Base64 URL Safe 编码(不带填充)
|
||||
|
||||
:param data: 要编码的字节数据
|
||||
:return: Base64 URL Safe 编码的字符串
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
|
||||
|
||||
@staticmethod
|
||||
def _base64_decode_urlsafe(data: str) -> bytes:
|
||||
"""
|
||||
Base64 URL Safe 解码(自动添加填充)
|
||||
|
||||
:param data: Base64 URL Safe 编码的字符串
|
||||
:return: 解码后的字节数据
|
||||
"""
|
||||
return base64.urlsafe_b64decode(data + '==')
|
||||
|
||||
@staticmethod
|
||||
def _parse_credential_list(credentials: List[Dict[str, Any]]) -> List[PublicKeyCredentialDescriptor]:
|
||||
"""
|
||||
解析凭证列表为 PublicKeyCredentialDescriptor 列表
|
||||
|
||||
:param credentials: 凭证字典列表
|
||||
:return: PublicKeyCredentialDescriptor 列表
|
||||
"""
|
||||
result = []
|
||||
for cred in credentials:
|
||||
try:
|
||||
result.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=PassKeyHelper._base64_decode_urlsafe(cred['credential_id']),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_user_verification_requirement(user_verification: Optional[str] = None) -> UserVerificationRequirement:
|
||||
"""
|
||||
获取用户验证要求
|
||||
|
||||
:param user_verification: 指定的用户验证要求,如果不指定则从配置中读取
|
||||
:return: UserVerificationRequirement
|
||||
"""
|
||||
if user_verification:
|
||||
return UserVerificationRequirement(user_verification)
|
||||
return UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
|
||||
@staticmethod
|
||||
def _get_verification_params(
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
获取验证参数(origin 和 rp_id)
|
||||
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (origin, rp_id)
|
||||
"""
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
return origin, rp_id
|
||||
|
||||
@staticmethod
|
||||
def generate_registration_options(
|
||||
user_id: int,
|
||||
@@ -109,27 +182,13 @@ class PassKeyHelper:
|
||||
try:
|
||||
# 用户信息
|
||||
user_id_bytes = str(user_id).encode('utf-8')
|
||||
|
||||
|
||||
# 排除已有的凭证
|
||||
exclude_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
exclude_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
exclude_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement()
|
||||
|
||||
# 生成注册选项
|
||||
options = generate_registration_options(
|
||||
@@ -138,7 +197,7 @@ class PassKeyHelper:
|
||||
user_id=user_id_bytes,
|
||||
user_name=username,
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
exclude_credentials=exclude_credentials,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
@@ -152,9 +211,9 @@ class PassKeyHelper:
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
|
||||
# 提取challenge(用于后续验证)
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
@@ -162,29 +221,6 @@ class PassKeyHelper:
|
||||
logger.error(f"生成注册选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str:
|
||||
"""
|
||||
在 localhost 环境下获取并验证实际 Origin,否则返回默认值
|
||||
"""
|
||||
if not settings.APP_DOMAIN and rp_id == 'localhost':
|
||||
try:
|
||||
# 解析 clientDataJSON 获取实际的 origin
|
||||
client_data_json = json.loads(
|
||||
base64.urlsafe_b64decode(
|
||||
credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '=='
|
||||
).decode('utf-8')
|
||||
)
|
||||
actual_origin = client_data_json.get('origin', '')
|
||||
hostname = urlparse(actual_origin).hostname
|
||||
|
||||
if hostname in ['localhost', '127.0.0.1']:
|
||||
logger.info(f"本地环境,使用动态 origin: {actual_origin}")
|
||||
return actual_origin
|
||||
except Exception as e:
|
||||
logger.warning(f"无法提取动态 origin: {e}")
|
||||
return default_origin
|
||||
|
||||
@staticmethod
|
||||
def verify_registration_response(
|
||||
credential: Dict[str, Any],
|
||||
@@ -203,18 +239,13 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码challenge
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
|
||||
# 构建RegistrationCredential对象
|
||||
registration_credential = parse_registration_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证注册响应
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
@@ -225,8 +256,8 @@ class PassKeyHelper:
|
||||
)
|
||||
|
||||
# 提取信息
|
||||
credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=')
|
||||
public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=')
|
||||
credential_id = PassKeyHelper._base64_encode_urlsafe(verification.credential_id)
|
||||
public_key = PassKeyHelper._base64_encode_urlsafe(verification.credential_public_key)
|
||||
sign_count = verification.sign_count
|
||||
# aaguid 可能已经是字符串格式,也可能是bytes
|
||||
if verification.aaguid:
|
||||
@@ -257,41 +288,24 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 允许的凭证
|
||||
allow_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
allow_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
allow_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
if not user_verification:
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
else:
|
||||
uv_requirement = UserVerificationRequirement(user_verification)
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement(user_verification)
|
||||
|
||||
# 生成认证选项
|
||||
options = generate_authentication_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
allow_credentials=allow_credentials if allow_credentials else None,
|
||||
allow_credentials=allow_credentials,
|
||||
user_verification=uv_requirement
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
|
||||
# 提取challenge
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
@@ -321,19 +335,14 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==')
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
public_key_bytes = PassKeyHelper._base64_decode_urlsafe(credential_public_key)
|
||||
|
||||
# 构建AuthenticationCredential对象
|
||||
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证认证响应
|
||||
verification = verify_authentication_response(
|
||||
credential=authentication_credential,
|
||||
|
||||
@@ -365,7 +365,8 @@ class Discord:
|
||||
else:
|
||||
# 匹配形如 "字段:值" 的片段,字段名不允许包含常见分隔符;
|
||||
# 下一个字段需以顿号/逗号/分号等分隔开,且不能是 URL 协议开头,避免值里出现 URL 的":" 被误拆
|
||||
name_re = r"[A-Za-z0-9\u4e00-\u9fa5_\-&]+"
|
||||
# 字段名允许 emoji 等 Unicode 字符,但排除空白/分隔符/冒号
|
||||
name_re = r"[^\s::,,。;;、]+"
|
||||
pair_pattern = re.compile(
|
||||
rf"({name_re})[::](.*?)(?=(?:[,,。;;、]+\s*(?!https?://|ftp://|ftps://|magnet:){name_re}[::])|$)",
|
||||
re.IGNORECASE,
|
||||
|
||||
@@ -132,6 +132,15 @@ class TransHandler:
|
||||
return self.result.model_copy()
|
||||
else:
|
||||
new_path = target_path / fileitem.name
|
||||
# 在整理目录前先尝试获取原盘大小,避免整理记录出现0字节的情况
|
||||
# TODO 当前只计算STREAM目录内的文件大小,如果需要精确则递归完整目录
|
||||
if stream_fileitem := source_oper.get_item(
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
):
|
||||
fileitem.size = 0
|
||||
files = source_oper.list(stream_fileitem) or []
|
||||
for file in files:
|
||||
fileitem.size += file.size
|
||||
# 整理目录
|
||||
new_diritem, errmsg = self.__transfer_dir(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.modules.indexer.spider import SiteSpider
|
||||
from app.modules.indexer.spider.haidan import HaiDanSpider
|
||||
from app.modules.indexer.spider.hddolby import HddolbySpider
|
||||
from app.modules.indexer.spider.mtorrent import MTorrentSpider
|
||||
from app.modules.indexer.spider.rousi import RousiSpider
|
||||
from app.modules.indexer.spider.tnode import TNodeSpider
|
||||
from app.modules.indexer.spider.torrentleech import TorrentLeech
|
||||
from app.modules.indexer.spider.yema import YemaSpider
|
||||
@@ -212,6 +213,13 @@ class IndexerModule(_ModuleBase):
|
||||
mtype=mtype,
|
||||
page=page
|
||||
)
|
||||
elif site.get('parser') == "RousiPro":
|
||||
error_flag, result = RousiSpider(site).search(
|
||||
keyword=search_word,
|
||||
mtype=mtype,
|
||||
cat=cat,
|
||||
page=page
|
||||
)
|
||||
else:
|
||||
error_flag, result = self.__spider_search(
|
||||
search_word=search_word,
|
||||
@@ -300,6 +308,13 @@ class IndexerModule(_ModuleBase):
|
||||
mtype=mtype,
|
||||
page=page
|
||||
)
|
||||
elif site.get('parser') == "RousiPro":
|
||||
error_flag, result = await RousiSpider(site).async_search(
|
||||
keyword=search_word,
|
||||
mtype=mtype,
|
||||
cat=cat,
|
||||
page=page
|
||||
)
|
||||
else:
|
||||
error_flag, result = await self.__async_spider_search(
|
||||
search_word=search_word,
|
||||
|
||||
@@ -35,6 +35,7 @@ class SiteSchema(Enum):
|
||||
HDDolby = "HDDolby"
|
||||
Zhixing = "Zhixing"
|
||||
Bitpt = "Bitpt"
|
||||
RousiPro = "RousiPro"
|
||||
|
||||
|
||||
class SiteParserBase(metaclass=ABCMeta):
|
||||
|
||||
233
app/modules/indexer/parser/rousi.py
Normal file
233
app/modules/indexer/parser/rousi.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.log import logger
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
|
||||
|
||||
class RousiSiteUserInfo(SiteParserBase):
|
||||
"""
|
||||
Rousi.pro 站点解析器
|
||||
使用 API v1 接口,通过 Passkey (Bearer Token) 进行认证
|
||||
"""
|
||||
schema = SiteSchema.RousiPro
|
||||
request_mode = "apikey"
|
||||
|
||||
def _parse_site_page(self, html_text: str):
|
||||
"""
|
||||
配置 API 请求地址和请求头
|
||||
使用 API v1 的 /profile 接口获取用户信息
|
||||
"""
|
||||
self._base_url = f"https://{StringUtils.get_url_domain(self._site_url)}"
|
||||
self._user_basic_page = "api/v1/profile?include_fields[user]=seeding_leeching_data"
|
||||
self._user_basic_params = {}
|
||||
self._user_basic_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {self.apikey}"
|
||||
}
|
||||
|
||||
# Rousi.pro API v1 在单个接口返回所有信息,无需额外页面
|
||||
self._user_traffic_page = None
|
||||
self._user_detail_page = None
|
||||
self._torrent_seeding_page = None
|
||||
self._user_mail_unread_page = None
|
||||
self._sys_mail_unread_page = None
|
||||
|
||||
def _parse_logged_in(self, html_text):
|
||||
"""
|
||||
判断是否登录成功
|
||||
API 认证模式下,通过 HTTP 状态码判断,此处始终返回 True
|
||||
"""
|
||||
return True
|
||||
|
||||
def _parse_user_base_info(self, html_text: str):
|
||||
"""
|
||||
解析用户基本信息
|
||||
通过 API v1 接口获取用户完整信息,包括上传下载量、做种数据等
|
||||
|
||||
API 响应示例:
|
||||
{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"id": 1,
|
||||
"username": "example",
|
||||
"level_text": "Lv.5",
|
||||
"registered_at": "2024-01-01T00:00:00Z",
|
||||
"uploaded": 1073741824,
|
||||
"downloaded": 536870912,
|
||||
"ratio": 2.0,
|
||||
"karma": 1000.5,
|
||||
"seeding_leeching_data": {
|
||||
"seeding_count": 10,
|
||||
"seeding_size": 10737418240,
|
||||
"leeching_count": 2,
|
||||
"leeching_size": 2147483648
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not html_text:
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(html_text)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"{self._site_name} JSON 解析失败")
|
||||
return
|
||||
|
||||
if not data or data.get("code") != 0:
|
||||
self.err_msg = data.get("message", "未知错误")
|
||||
logger.warn(f"{self._site_name} API 错误: {self.err_msg}")
|
||||
return
|
||||
|
||||
user_info = data.get("data")
|
||||
if not user_info:
|
||||
return
|
||||
|
||||
# 基本信息
|
||||
self.userid = user_info.get("id")
|
||||
self.username = user_info.get("username")
|
||||
self.user_level = user_info.get("level_text") or user_info.get("role_text")
|
||||
|
||||
# 注册时间:统一格式为 YYYY-MM-DD HH:MM:SS
|
||||
join_at = StringUtils.unify_datetime_str(user_info.get("registered_at"))
|
||||
if join_at:
|
||||
# 确保格式为 YYYY-MM-DD HH:MM:SS (19位)
|
||||
if len(join_at) >= 19:
|
||||
self.join_at = join_at[:19]
|
||||
else:
|
||||
self.join_at = join_at
|
||||
|
||||
# 流量信息
|
||||
self.upload = int(user_info.get("uploaded") or 0)
|
||||
self.download = int(user_info.get("downloaded") or 0)
|
||||
self.ratio = round(float(user_info.get("ratio") or 0), 2)
|
||||
|
||||
# 魔力值(站点称为 karma)
|
||||
self.bonus = float(user_info.get("karma") or 0)
|
||||
|
||||
# 做种/下载中数据
|
||||
sl_data = user_info.get("seeding_leeching_data", {})
|
||||
self.seeding = int(sl_data.get("seeding_count") or 0)
|
||||
self.seeding_size = int(sl_data.get("seeding_size") or 0)
|
||||
self.leeching = int(sl_data.get("leeching_count") or 0)
|
||||
self.leeching_size = int(sl_data.get("leeching_size") or 0)
|
||||
|
||||
def _parse_user_traffic_info(self, html_text: str):
|
||||
"""
|
||||
解析用户流量信息
|
||||
Rousi.pro API v1 在 _parse_user_base_info 中已完成所有解析,此方法无需实现
|
||||
"""
|
||||
pass
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
"""
|
||||
解析用户详细信息
|
||||
Rousi.pro API v1 在 _parse_user_base_info 中已完成所有解析,此方法无需实现
|
||||
"""
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str, multi_page: Optional[bool] = False) -> Optional[str]:
|
||||
"""
|
||||
解析用户做种信息
|
||||
Rousi.pro API v1 在 _parse_user_base_info 中已通过 seeding_leeching_data 获取做种数据
|
||||
|
||||
:param html_text: 页面内容
|
||||
:param multi_page: 是否多页数据
|
||||
:return: 下页地址(无下页返回 None)
|
||||
"""
|
||||
return None
|
||||
|
||||
def _parse_message_unread_links(self, html_text: str, msg_links: list) -> Optional[str]:
|
||||
"""
|
||||
解析未读消息链接
|
||||
Rousi.pro API v1 暂未提供消息相关接口
|
||||
|
||||
:param html_text: 页面内容
|
||||
:param msg_links: 消息链接列表
|
||||
:return: 下页地址(无下页返回 None)
|
||||
"""
|
||||
return None
|
||||
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
解析消息内容
|
||||
Rousi.pro API v1 暂未提供消息相关接口
|
||||
|
||||
:param html_text: 页面内容
|
||||
:return: (标题, 日期, 内容)
|
||||
"""
|
||||
return None, None, None
|
||||
|
||||
def _pase_unread_msgs(self):
|
||||
"""
|
||||
解析所有未读消息标题和内容
|
||||
Rousi.pro API v1 暂未提供消息相关接口,暂时以网页接口实现
|
||||
|
||||
:return:
|
||||
"""
|
||||
if not self.token:
|
||||
logger.warn(f"{self._site_name} 站点未配置 Authorization 请求头,跳过消息解析")
|
||||
return
|
||||
|
||||
headers = {
|
||||
"User-Agent": self._ua,
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Authorization": self.token if self.token.startswith("Bearer ") else f"Bearer {self.token}"
|
||||
}
|
||||
|
||||
def __get_message_list(page: int):
|
||||
params = {
|
||||
"page": page,
|
||||
"page_size": 100,
|
||||
"unread_only": "true"
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
proxies=settings.PROXY if self._proxy else None
|
||||
).get_res(
|
||||
url=urljoin(self._base_url, "api/messages"),
|
||||
params=params
|
||||
)
|
||||
if not res or res.status_code != 200 or not res.text:
|
||||
logger.warn(f"{self._site_name} 站点解析消息失败,状态码: {res.status_code if res else '无响应'}")
|
||||
return {
|
||||
"messages": [],
|
||||
"total_pages": 0
|
||||
}
|
||||
return res.json()
|
||||
|
||||
# 分页获取所有未读消息
|
||||
page = 0
|
||||
res = __get_message_list(page)
|
||||
page += 1
|
||||
messages = res.get("messages", [])
|
||||
total_pages = res.get("total_pages", 0)
|
||||
while page < total_pages:
|
||||
res = __get_message_list(page)
|
||||
messages.extend(res.get("messages", []))
|
||||
page += 1
|
||||
|
||||
for messsage in messages:
|
||||
head = messsage.get("title")
|
||||
date = StringUtils.unify_datetime_str(messsage.get("created_at"))
|
||||
content = messsage.get("content")
|
||||
logger.debug(f"{self._site_name} 标题 {head} 时间 {date} 内容 {content}")
|
||||
self.message_unread_contents.append((head, date, content))
|
||||
|
||||
# 更新消息为已读
|
||||
RequestUtils(
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
proxies=settings.PROXY if self._proxy else None
|
||||
).post_res(
|
||||
url=urljoin(self._base_url, "api/messages/read-all")
|
||||
)
|
||||
@@ -118,24 +118,36 @@ class MTorrentSpider:
|
||||
labels_value = self._labels.get(result.get('labels') or "0") or ""
|
||||
if labels_value:
|
||||
labels = labels_value.split()
|
||||
status = result.get('status', {})
|
||||
torrent = {
|
||||
'title': result.get('name'),
|
||||
'description': result.get('smallDescr'),
|
||||
'enclosure': self.__get_download_url(result.get('id')),
|
||||
'pubdate': StringUtils.format_timestamp(result.get('createdDate')),
|
||||
'size': int(result.get('size') or '0'),
|
||||
'seeders': int(result.get('status', {}).get("seeders") or '0'),
|
||||
'peers': int(result.get('status', {}).get("leechers") or '0'),
|
||||
'grabs': int(result.get('status', {}).get("timesCompleted") or '0'),
|
||||
'downloadvolumefactor': self.__get_downloadvolumefactor(result.get('status', {}).get("discount")),
|
||||
'uploadvolumefactor': self.__get_uploadvolumefactor(result.get('status', {}).get("discount")),
|
||||
'seeders': int(status.get("seeders") or '0'),
|
||||
'peers': int(status.get("leechers") or '0'),
|
||||
'grabs': int(status.get("timesCompleted") or '0'),
|
||||
'downloadvolumefactor': self.__get_downloadvolumefactor(status.get("discount")),
|
||||
'uploadvolumefactor': self.__get_uploadvolumefactor(status.get("discount")),
|
||||
'page_url': self._pageurl % (self._url, result.get('id')),
|
||||
'imdbid': self.__find_imdbid(result.get('imdb')),
|
||||
'labels': labels,
|
||||
'category': category
|
||||
}
|
||||
if discount_end_time := (result.get('status') or {}).get('discountEndTime'):
|
||||
if discount_end_time := status.get('discountEndTime'):
|
||||
torrent['freedate'] = StringUtils.format_timestamp(discount_end_time)
|
||||
# 解析全站促销时的规则(当前馒头只有下载促销)
|
||||
if promotion_rule := status.get("promotionRule"):
|
||||
discount = promotion_rule.get("discount", "NORMAL")
|
||||
torrent["downloadvolumefactor"] = self.__get_downloadvolumefactor(discount)
|
||||
if end_time := promotion_rule.get("endTime"):
|
||||
torrent["freedate"] = StringUtils.format_timestamp(end_time)
|
||||
if mall_single_free := status.get("mallSingleFree"):
|
||||
if mall_single_free.get("status") == "ONGOING":
|
||||
torrent["downloadvolumefactor"] = self.__get_downloadvolumefactor("FREE")
|
||||
if end_date := mall_single_free.get("endDate"):
|
||||
torrent["freedate"] = StringUtils.format_timestamp(end_date)
|
||||
torrents.append(torrent)
|
||||
return torrents
|
||||
|
||||
|
||||
289
app/modules/indexer/spider/rousi.py
Normal file
289
app/modules/indexer/spider/rousi.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import base64
|
||||
import json
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class RousiSpider:
|
||||
"""
|
||||
Rousi.pro API v1 Spider
|
||||
|
||||
使用 API v1 接口进行种子搜索
|
||||
- 认证方式:Bearer Token (Passkey)
|
||||
- 搜索接口:/api/v1/torrents
|
||||
- 详情接口:/api/v1/torrents/:id
|
||||
"""
|
||||
_indexerid = None
|
||||
_domain = None
|
||||
_url = None
|
||||
_name = ""
|
||||
_proxy = None
|
||||
_cookie = None
|
||||
_ua = None
|
||||
_size = 100
|
||||
_searchurl = "https://%s/api/v1/torrents"
|
||||
_downloadurl = "https://%s/api/v1/torrents/%s"
|
||||
_timeout = 15
|
||||
|
||||
# 分类定义
|
||||
# API 不支持多分类搜索,每次只使用一个分类
|
||||
_movie_category = 'movie'
|
||||
_tv_category = 'tv'
|
||||
|
||||
# API KEY
|
||||
_apikey = None
|
||||
|
||||
def __init__(self, indexer: dict):
|
||||
self.systemconfig = SystemConfigOper()
|
||||
if indexer:
|
||||
self._indexerid = indexer.get('id')
|
||||
self._url = indexer.get('domain')
|
||||
self._domain = StringUtils.get_url_domain(self._url)
|
||||
self._searchurl = self._searchurl % self._domain
|
||||
self._downloadurl = self._downloadurl % (self._domain, "%s")
|
||||
self._name = indexer.get('name')
|
||||
if indexer.get('proxy'):
|
||||
self._proxy = settings.PROXY
|
||||
self._cookie = indexer.get('cookie')
|
||||
self._ua = indexer.get('ua')
|
||||
self._apikey = indexer.get('apikey')
|
||||
self._timeout = indexer.get('timeout') or 15
|
||||
|
||||
def __get_params(self, keyword: str, mtype: MediaType = None, cat: Optional[str] = None, page: Optional[int] = 0) -> dict:
|
||||
"""
|
||||
构建 API 请求参数
|
||||
|
||||
:param keyword: 搜索关键词
|
||||
:param mtype: 媒体类型 (MOVIE/TV)
|
||||
:param cat: 用户选择的分类 ID(逗号分隔的字符串)
|
||||
:param page: 页码(从 0 开始,API 需要从 1 开始)
|
||||
:return: 请求参数字典
|
||||
"""
|
||||
params = {
|
||||
"page": int(page) + 1,
|
||||
"page_size": self._size
|
||||
}
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
|
||||
# API 不支持多分类搜索,只使用单个 category 参数
|
||||
# 优先使用用户选择的分类,如果用户未选择则根据 mtype 推断
|
||||
if cat:
|
||||
# 用户选择了特定分类,需要将分类 ID 映射回 API 的 category name
|
||||
category_names = self.__get_category_names_by_ids(cat)
|
||||
if category_names:
|
||||
# 如果用户选择了多个分类,只取第一个
|
||||
params["category"] = category_names[0]
|
||||
elif mtype:
|
||||
# 用户未选择分类,根据媒体类型推断
|
||||
if mtype == MediaType.MOVIE:
|
||||
params["category"] = self._movie_category
|
||||
elif mtype == MediaType.TV:
|
||||
params["category"] = self._tv_category
|
||||
|
||||
return params
|
||||
|
||||
def __get_category_names_by_ids(self, cat: str) -> Optional[list]:
|
||||
"""
|
||||
根据用户选择的分类 ID 获取 API 的 category names
|
||||
|
||||
:param cat: 用户选择的分类 ID(逗号分隔的多个ID,如 "1,2,3")
|
||||
:return: API 的 category names 列表(如 ["movie", "tv", "documentary"])
|
||||
"""
|
||||
if not cat:
|
||||
return None
|
||||
|
||||
# ID 到 category name 的映射
|
||||
id_to_name = {
|
||||
'1': 'movie',
|
||||
'2': 'tv',
|
||||
'3': 'documentary',
|
||||
'4': 'animation',
|
||||
'6': 'variety'
|
||||
}
|
||||
|
||||
# 分割多个分类 ID 并映射为 category names
|
||||
cat_ids = [c.strip() for c in cat.split(',') if c.strip()]
|
||||
category_names = [id_to_name.get(cat_id) for cat_id in cat_ids if cat_id in id_to_name]
|
||||
|
||||
return category_names if category_names else None
|
||||
|
||||
def __process_response(self, res) -> Tuple[bool, List[dict]]:
|
||||
"""
|
||||
处理 API 响应
|
||||
|
||||
:param res: 请求响应对象
|
||||
:return: (是否发生错误, 种子列表)
|
||||
"""
|
||||
if res and res.status_code == 200:
|
||||
try:
|
||||
data = res.json()
|
||||
if data.get('code') == 0:
|
||||
results = data.get('data', {}).get('torrents', [])
|
||||
return False, self.__parse_result(results)
|
||||
else:
|
||||
logger.warn(f"{self._name} 搜索失败,错误信息:{data.get('message')}")
|
||||
return True, []
|
||||
except Exception as e:
|
||||
logger.warn(f"{self._name} 解析响应失败:{e}")
|
||||
return True, []
|
||||
elif res is not None:
|
||||
logger.warn(f"{self._name} 搜索失败,HTTP 错误码:{res.status_code}")
|
||||
return True, []
|
||||
else:
|
||||
logger.warn(f"{self._name} 搜索失败,无法连接 {self._domain}")
|
||||
return True, []
|
||||
|
||||
def __parse_result(self, results: List[dict]) -> List[dict]:
|
||||
"""
|
||||
解析搜索结果
|
||||
|
||||
将 API 返回的种子数据转换为 MoviePilot 标准格式
|
||||
|
||||
:param results: API 返回的种子列表
|
||||
:return: 标准化的种子信息列表
|
||||
"""
|
||||
torrents = []
|
||||
if not results:
|
||||
return torrents
|
||||
|
||||
for result in results:
|
||||
# 解析分类信息
|
||||
raw_cat = result.get('category')
|
||||
cat_val = None
|
||||
|
||||
category = MediaType.UNKNOWN.value
|
||||
|
||||
if isinstance(raw_cat, dict):
|
||||
cat_val = raw_cat.get('slug') or raw_cat.get('name')
|
||||
elif isinstance(raw_cat, str):
|
||||
cat_val = raw_cat
|
||||
|
||||
if cat_val:
|
||||
cat_val = str(cat_val).lower()
|
||||
if cat_val == self._movie_category:
|
||||
category = MediaType.MOVIE.value
|
||||
elif cat_val == self._tv_category:
|
||||
category = MediaType.TV.value
|
||||
else:
|
||||
category = MediaType.UNKNOWN.value
|
||||
|
||||
# 解析促销信息
|
||||
# API 后端已处理全站促销优先级,直接使用返回的 promotion 数据
|
||||
downloadvolumefactor = 1.0
|
||||
uploadvolumefactor = 1.0
|
||||
freedate = None
|
||||
|
||||
promotion = result.get('promotion')
|
||||
if promotion and promotion.get('is_active'):
|
||||
downloadvolumefactor = float(promotion.get('down_multiplier', 1.0))
|
||||
uploadvolumefactor = float(promotion.get('up_multiplier', 1.0))
|
||||
# 促销到期时间,格式化为 YYYY-MM-DD HH:MM:SS
|
||||
if promotion.get('until'):
|
||||
freedate = StringUtils.unify_datetime_str(promotion.get('until'))
|
||||
|
||||
torrent = {
|
||||
'title': result.get('title'),
|
||||
'description': result.get('subtitle'),
|
||||
'enclosure': self.__get_download_url(result.get('id')),
|
||||
'pubdate': StringUtils.unify_datetime_str(result.get('created_at')),
|
||||
'size': int(result.get('size') or 0),
|
||||
'seeders': int(result.get('seeders') or 0),
|
||||
'peers': int(result.get('leechers') or 0),
|
||||
'grabs': int(result.get('downloads') or 0),
|
||||
'downloadvolumefactor': downloadvolumefactor,
|
||||
'uploadvolumefactor': uploadvolumefactor,
|
||||
'freedate': freedate,
|
||||
'page_url': f"https://{self._domain}/torrent/{result.get('uuid')}",
|
||||
'labels': [],
|
||||
'category': category
|
||||
}
|
||||
torrents.append(torrent)
|
||||
return torrents
|
||||
|
||||
def search(self, keyword: str, mtype: MediaType = None, cat: Optional[str] = None, page: Optional[int] = 0) -> Tuple[bool, List[dict]]:
|
||||
"""
|
||||
同步搜索种子
|
||||
|
||||
:param keyword: 搜索关键词
|
||||
:param mtype: 媒体类型 (MOVIE/TV)
|
||||
:param cat: 用户选择的分类 ID(逗号分隔)
|
||||
:param page: 页码(从 0 开始)
|
||||
:return: (是否发生错误, 种子列表)
|
||||
"""
|
||||
if not self._apikey:
|
||||
logger.warn(f"{self._name} 未配置 API Key (Passkey)")
|
||||
return True, []
|
||||
|
||||
params = self.__get_params(keyword, mtype, cat, page)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._apikey}",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
proxies=self._proxy,
|
||||
timeout=self._timeout
|
||||
).get_res(url=self._searchurl, params=params)
|
||||
|
||||
return self.__process_response(res)
|
||||
|
||||
async def async_search(self, keyword: str, mtype: MediaType = None, cat: Optional[str] = None, page: Optional[int] = 0) -> Tuple[bool, List[dict]]:
|
||||
"""
|
||||
异步搜索种子
|
||||
|
||||
:param keyword: 搜索关键词
|
||||
:param mtype: 媒体类型 (MOVIE/TV)
|
||||
:param cat: 用户选择的分类 ID(逗号分隔)
|
||||
:param page: 页码(从 0 开始)
|
||||
:return: (是否发生错误, 种子列表)
|
||||
"""
|
||||
if not self._apikey:
|
||||
logger.warn(f"{self._name} 未配置 API Key (Passkey)")
|
||||
return True, []
|
||||
|
||||
params = self.__get_params(keyword, mtype, cat, page)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._apikey}",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
res = await AsyncRequestUtils(
|
||||
headers=headers,
|
||||
proxies=self._proxy,
|
||||
timeout=self._timeout
|
||||
).get_res(url=self._searchurl, params=params)
|
||||
|
||||
return self.__process_response(res)
|
||||
|
||||
def __get_download_url(self, torrent_id: int) -> str:
|
||||
"""
|
||||
构建种子下载链接
|
||||
|
||||
使用 base64 编码的方式告诉 MoviePilot 如何获取真实下载地址
|
||||
MoviePilot 会先请求详情接口,然后从响应中提取 data.download_url
|
||||
|
||||
:param torrent_id: 种子 ID
|
||||
:return: base64 编码的请求配置字符串 + 详情接口 URL
|
||||
"""
|
||||
url = self._downloadurl % torrent_id
|
||||
# MoviePilot 会解析这个特殊格式的 URL:
|
||||
# 1. 使用指定的 method 和 header 请求 URL
|
||||
# 2. 从 JSON 响应中提取 result 指定的字段值作为真实下载地址
|
||||
params = {
|
||||
'method': 'get',
|
||||
'header': {
|
||||
'Authorization': f'Bearer {self._apikey}',
|
||||
'Accept': 'application/json'
|
||||
},
|
||||
'result': 'data.download_url'
|
||||
}
|
||||
base64_str = base64.b64encode(json.dumps(params).encode('utf-8')).decode('utf-8')
|
||||
return f"[{base64_str}]{url}"
|
||||
@@ -85,7 +85,11 @@ class SubtitleModule(_ModuleBase):
|
||||
)
|
||||
# TODO 其它采用API访问的站点
|
||||
# 普通站点通过解析网站代码的方式获取
|
||||
request = RequestUtils(cookies=torrent.site_cookie, ua=torrent.site_ua)
|
||||
request = RequestUtils(
|
||||
cookies=torrent.site_cookie,
|
||||
ua=torrent.site_ua,
|
||||
proxies=settings.PROXY if torrent.site_proxy else None,
|
||||
)
|
||||
res = request.get_res(torrent.page_url)
|
||||
if res and res.status_code == 200:
|
||||
if not res.text:
|
||||
@@ -176,7 +180,11 @@ class SubtitleModule(_ModuleBase):
|
||||
logger.warn(f"{torrent.page_url} 页面未找到字幕下载链接")
|
||||
return
|
||||
# 下载所有字幕文件
|
||||
request = RequestUtils(cookies=torrent.site_cookie, ua=torrent.site_ua)
|
||||
request = RequestUtils(
|
||||
cookies=torrent.site_cookie,
|
||||
ua=torrent.site_ua,
|
||||
proxies=settings.PROXY if torrent.site_proxy else None,
|
||||
)
|
||||
for sublink in sublink_list:
|
||||
logger.info(f"找到字幕下载链接:{sublink},开始下载...")
|
||||
# 下载
|
||||
|
||||
@@ -867,19 +867,19 @@ class TheMovieDbModule(_ModuleBase):
|
||||
backdrops = images.get("backdrops")
|
||||
if backdrops:
|
||||
backdrops = sorted(backdrops, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.backdrop_path = backdrops[0].get("file_path")
|
||||
mediainfo.backdrop_path = settings.TMDB_IMAGE_URL(backdrops[0].get("file_path"))
|
||||
# 标志
|
||||
if not mediainfo.logo_path:
|
||||
logos = images.get("logos")
|
||||
if logos:
|
||||
logos = sorted(logos, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.logo_path = logos[0].get("file_path")
|
||||
mediainfo.logo_path = settings.TMDB_IMAGE_URL(logos[0].get("file_path"))
|
||||
# 海报
|
||||
if not mediainfo.poster_path:
|
||||
posters = images.get("posters")
|
||||
if posters:
|
||||
posters = sorted(posters, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.poster_path = posters[0].get("file_path")
|
||||
mediainfo.poster_path = settings.TMDB_IMAGE_URL(posters[0].get("file_path"))
|
||||
return mediainfo
|
||||
|
||||
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
|
||||
@@ -957,7 +957,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
image_path = seasoninfo.get(image_type.value)
|
||||
|
||||
if image_path:
|
||||
return f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/{image_prefix}{image_path}"
|
||||
return settings.TMDB_IMAGE_URL(image_path, image_prefix)
|
||||
return None
|
||||
|
||||
def tmdb_movie_similar(self, tmdbid: int) -> List[MediaInfo]:
|
||||
|
||||
@@ -85,10 +85,10 @@ class TmdbScraper:
|
||||
seasoninfo = self.original_tmdb(mediainfo).get_tv_season_detail(mediainfo.tmdb_id, season)
|
||||
if seasoninfo:
|
||||
episodeinfo = self.__get_episode_detail(seasoninfo, episode)
|
||||
if episodeinfo and episodeinfo.get("still_path"):
|
||||
if still_path := episodeinfo.get("still_path"):
|
||||
# TMDB集still图片
|
||||
still_name = f"{episode}"
|
||||
still_url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{episodeinfo.get('still_path')}"
|
||||
still_url = settings.TMDB_IMAGE_URL(still_path)
|
||||
images[still_name] = still_url
|
||||
else:
|
||||
# 季的图片
|
||||
@@ -115,7 +115,7 @@ class TmdbScraper:
|
||||
if _mediainfo:
|
||||
for attr_name, attr_value in _mediainfo.items():
|
||||
if attr_name.endswith("_path") and attr_value is not None:
|
||||
image_url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{attr_value}"
|
||||
image_url = settings.TMDB_IMAGE_URL(attr_value)
|
||||
image_name = attr_name.replace("_path", "") + Path(image_url).suffix
|
||||
images[image_name] = image_url
|
||||
return images
|
||||
@@ -127,11 +127,11 @@ class TmdbScraper:
|
||||
"""
|
||||
# TMDB季poster图片
|
||||
sea_seq = str(season).rjust(2, '0')
|
||||
if seasoninfo.get("poster_path"):
|
||||
if poster_path := seasoninfo.get("poster_path"):
|
||||
# 后缀
|
||||
ext = Path(seasoninfo.get('poster_path')).suffix
|
||||
ext = Path(poster_path).suffix
|
||||
# URL
|
||||
url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{seasoninfo.get('poster_path')}"
|
||||
url = settings.TMDB_IMAGE_URL(poster_path)
|
||||
# S0海报格式不同
|
||||
if season == 0:
|
||||
image_name = f"season-specials-poster{ext}"
|
||||
@@ -190,8 +190,8 @@ class TmdbScraper:
|
||||
DomUtils.add_node(doc, xactor, "type", "Actor")
|
||||
DomUtils.add_node(doc, xactor, "role", actor.get("character") or actor.get("role") or "")
|
||||
DomUtils.add_node(doc, xactor, "tmdbid", actor.get("id") or "")
|
||||
DomUtils.add_node(doc, xactor, "thumb",
|
||||
f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{actor.get('profile_path')}")
|
||||
if profile_path := actor.get('profile_path'):
|
||||
DomUtils.add_node(doc, xactor, "thumb", settings.TMDB_IMAGE_URL(profile_path))
|
||||
DomUtils.add_node(doc, xactor, "profile",
|
||||
f"https://www.themoviedb.org/person/{actor.get('id')}")
|
||||
# 风格
|
||||
@@ -297,7 +297,8 @@ class TmdbScraper:
|
||||
uniqueid.setAttribute("type", "tmdb")
|
||||
uniqueid.setAttribute("default", "true")
|
||||
# tmdbid
|
||||
DomUtils.add_node(doc, root, "tmdbid", str(tmdbid))
|
||||
# 应与uniqueid一致 使用剧集id 否则jellyfin/emby会将此id覆盖上面的uniqueid
|
||||
DomUtils.add_node(doc, root, "tmdbid", str(episodeinfo.get("id")))
|
||||
# 标题
|
||||
DomUtils.add_node(doc, root, "title", episodeinfo.get("name") or "第 %s 集" % episode)
|
||||
# 简介
|
||||
@@ -330,8 +331,8 @@ class TmdbScraper:
|
||||
DomUtils.add_node(doc, xactor, "name", actor.get("name") or "")
|
||||
DomUtils.add_node(doc, xactor, "type", "Actor")
|
||||
DomUtils.add_node(doc, xactor, "tmdbid", actor.get("id") or "")
|
||||
DomUtils.add_node(doc, xactor, "thumb",
|
||||
f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{actor.get('profile_path')}")
|
||||
if profile_path := actor.get('profile_path'):
|
||||
DomUtils.add_node(doc, xactor, "thumb", settings.TMDB_IMAGE_URL(profile_path))
|
||||
DomUtils.add_node(doc, xactor, "profile",
|
||||
f"https://www.themoviedb.org/person/{actor.get('id')}")
|
||||
return doc
|
||||
|
||||
@@ -50,7 +50,7 @@ class TmdbCache(metaclass=WeakSingleton):
|
||||
"""
|
||||
获取缓存KEY
|
||||
"""
|
||||
return f"[{meta.type.value if meta.type else '未知'}]{meta.tmdbid or meta.name}-{meta.year}-{meta.begin_season}"
|
||||
return f"[{meta.type.value if meta.type else '未知'}][{settings.TMDB_LOCALE}]{meta.tmdbid or meta.name}-{meta.year}-{meta.begin_season}"
|
||||
|
||||
def get(self, meta: MetaBase):
|
||||
"""
|
||||
|
||||
@@ -826,7 +826,7 @@ class TmdbApi:
|
||||
# 转换多语种标题
|
||||
self.__update_tmdbinfo_extra_title(tmdb_info)
|
||||
# 转换中文标题
|
||||
if settings.TMDB_LOCALE == "zh":
|
||||
if self.tmdb.language in ("zh", "zh-CN"):
|
||||
self.__update_tmdbinfo_cn_title(tmdb_info)
|
||||
|
||||
return tmdb_info
|
||||
@@ -2134,7 +2134,7 @@ class TmdbApi:
|
||||
# 转换多语种标题
|
||||
self.__update_tmdbinfo_extra_title(tmdb_info)
|
||||
# 转换中文标题
|
||||
if settings.TMDB_LOCALE == "zh":
|
||||
if self.tmdb.language in ("zh", "zh-CN"):
|
||||
self.__update_tmdbinfo_cn_title(tmdb_info)
|
||||
|
||||
return tmdb_info
|
||||
|
||||
@@ -695,11 +695,13 @@ class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
|
||||
# 全程加锁
|
||||
with lock:
|
||||
is_bluray_folder = False
|
||||
# 蓝光原盘文件处理
|
||||
if __is_bluray_sub(event_path):
|
||||
event_path = __get_bluray_dir(event_path)
|
||||
if not event_path:
|
||||
return
|
||||
is_bluray_folder = True
|
||||
|
||||
# TTL缓存控重
|
||||
if self._cache.get(str(event_path)):
|
||||
@@ -708,13 +710,20 @@ class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
self._cache[str(event_path)] = True
|
||||
|
||||
try:
|
||||
logger.info(f"开始整理文件: {event_path}")
|
||||
if is_bluray_folder:
|
||||
logger.info(f"开始整理蓝光原盘: {event_path}")
|
||||
else:
|
||||
logger.info(f"开始整理文件: {event_path}")
|
||||
# 开始整理
|
||||
TransferChain().do_transfer(
|
||||
fileitem=FileItem(
|
||||
storage=storage,
|
||||
path=event_path.as_posix(),
|
||||
type="file",
|
||||
path=(
|
||||
event_path.as_posix()
|
||||
if not is_bluray_folder
|
||||
else event_path.as_posix() + "/"
|
||||
),
|
||||
type="file" if not is_bluray_folder else "dir",
|
||||
name=event_path.name,
|
||||
basename=event_path.stem,
|
||||
extension=event_path.suffix[1:],
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""AI智能体初始化器"""
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_manager = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
@@ -26,9 +23,6 @@ class AgentInitializer:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
from app.agent import agent_manager
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
await agent_manager.initialize()
|
||||
self._initialized = True
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
@@ -43,10 +37,10 @@ class AgentInitializer:
|
||||
清理AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not self._initialized or not self.agent_manager:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
await self.agent_manager.close()
|
||||
await agent_manager.close()
|
||||
self._initialized = False
|
||||
logger.info("AI智能体管理器已关闭")
|
||||
|
||||
@@ -78,8 +72,8 @@ def init_agent():
|
||||
else:
|
||||
logger.error("AI智能体管理器初始化失败")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
except Exception as err:
|
||||
logger.error(f"AI智能体管理器初始化失败: {err}")
|
||||
return False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@@ -26,12 +26,14 @@ def cookie_parse(cookies_str: str, array: bool = False) -> Union[list, dict]:
|
||||
"""
|
||||
if not cookies_str:
|
||||
return {}
|
||||
from urllib.parse import unquote
|
||||
cookie_dict = {}
|
||||
cookies = cookies_str.split(";")
|
||||
for cookie in cookies:
|
||||
cstr = cookie.split("=")
|
||||
cstr = cookie.split("=", 1) # 只分割第一个=,因为value可能包含=
|
||||
if len(cstr) > 1:
|
||||
cookie_dict[cstr[0].strip()] = cstr[1].strip()
|
||||
# URL解码Cookie值(但保留Cookie名不解码)
|
||||
cookie_dict[cstr[0].strip()] = unquote(cstr[1].strip())
|
||||
if array:
|
||||
return [{"name": k, "value": v} for k, v in cookie_dict.items()]
|
||||
return cookie_dict
|
||||
@@ -654,7 +656,8 @@ class AsyncRequestUtils:
|
||||
proxy=self._proxies,
|
||||
timeout=self._timeout,
|
||||
verify=False,
|
||||
follow_redirects=True
|
||||
follow_redirects=True,
|
||||
cookies=self._cookies # 在创建客户端时传入Cookie
|
||||
) as client:
|
||||
return await self._make_request(client, method, url, raise_exception, **kwargs)
|
||||
else:
|
||||
@@ -666,7 +669,8 @@ class AsyncRequestUtils:
|
||||
执行实际的异步请求
|
||||
"""
|
||||
kwargs.setdefault("headers", self._headers)
|
||||
kwargs.setdefault("cookies", self._cookies)
|
||||
# Cookie已经在AsyncClient创建时设置,不要在request时再设置,否则会被覆盖
|
||||
# kwargs.setdefault("cookies", self._cookies)
|
||||
|
||||
try:
|
||||
return await client.request(method, url, **kwargs)
|
||||
|
||||
@@ -242,6 +242,27 @@ class StringUtils:
|
||||
else:
|
||||
return size + "B"
|
||||
|
||||
@staticmethod
|
||||
def format_size(size_bytes: int) -> str:
|
||||
"""
|
||||
将字节转换为人类可读格式
|
||||
"""
|
||||
if not size_bytes or size_bytes == 0:
|
||||
return "0 B"
|
||||
|
||||
units = ["B", "KB", "MB", "GB", "TB", "PB"]
|
||||
size = float(size_bytes)
|
||||
unit_index = 0
|
||||
|
||||
while size >= 1024 and unit_index < len(units) - 1:
|
||||
size /= 1024
|
||||
unit_index += 1
|
||||
|
||||
# 保留两位小数
|
||||
if unit_index == 0:
|
||||
return f"{int(size)} {units[unit_index]}"
|
||||
return f"{size:.2f} {units[unit_index]}"
|
||||
|
||||
@staticmethod
|
||||
def url_equal(url1: str, url2: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -479,6 +479,8 @@ class SystemUtils:
|
||||
def is_bluray_dir(dir_path: Path) -> bool:
|
||||
"""
|
||||
判断是否为蓝光原盘目录
|
||||
|
||||
(该方法已弃用,改用`StorageChain().is_bluray_folder)`
|
||||
"""
|
||||
if not dir_path.is_dir():
|
||||
return False
|
||||
|
||||
41
database/versions/41ef1dd7467c_2_2_2.py
Normal file
41
database/versions/41ef1dd7467c_2_2_2.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""2.2.2
|
||||
|
||||
Revision ID: 41ef1dd7467c
|
||||
Revises: a946dae52526
|
||||
Create Date: 2026-01-13 13:02:41.614029
|
||||
|
||||
"""
|
||||
|
||||
from app.db import ScopedSession
|
||||
from app.db.models.systemconfig import SystemConfig
|
||||
from app.log import logger
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "41ef1dd7467c"
|
||||
down_revision = "a946dae52526"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# systemconfig表 去重
|
||||
with ScopedSession() as db:
|
||||
try:
|
||||
seen_keys = set()
|
||||
# 按ID降序查询,以便保留最新的配置
|
||||
for item in db.query(SystemConfig).order_by(SystemConfig.id.desc()).all():
|
||||
if item.key in seen_keys:
|
||||
logger.warn(
|
||||
f"已删除重复的SystemConfig项:{item.key} 值:{item.value}"
|
||||
)
|
||||
db.delete(item)
|
||||
else:
|
||||
seen_keys.add(item.key)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
db.rollback()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
30
database/versions/58edfac72c32_2_2_3.py
Normal file
30
database/versions/58edfac72c32_2_2_3.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""2.2.3
|
||||
添加 downloadhistory.custom_words 字段,用于整理时应用订阅识别词
|
||||
|
||||
Revision ID: 58edfac72c32
|
||||
Revises: 41ef1dd7467c
|
||||
Create Date: 2026-01-19
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "58edfac72c32"
|
||||
down_revision = "41ef1dd7467c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
# 检查并添加 downloadhistory.custom_words
|
||||
dh_columns = inspector.get_columns('downloadhistory')
|
||||
if not any(c['name'] == 'custom_words' for c in dh_columns):
|
||||
op.add_column('downloadhistory', sa.Column('custom_words', sa.String, nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# 降级时删除字段
|
||||
op.drop_column('downloadhistory', 'custom_words')
|
||||
@@ -1,161 +1,99 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# 文件列表结构 list[tuple(名称, 子文件列表 或 文件大小)]
|
||||
bluray_files = [
|
||||
{
|
||||
"name": "FOLDER",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon (2055)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
(
|
||||
"FOLDER",
|
||||
[
|
||||
(
|
||||
"Digimon",
|
||||
[
|
||||
(
|
||||
"Digimon BluRay (2055)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Digimon (2099)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00002.m2ts.!qB",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Digimon BluRay (2099)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
("00002.m2ts.!qB", 104857600),
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Digimon (2199)",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon.2199.mp4",
|
||||
"size": 104857600,
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
("Digimon (2199)", [("Digimon.2199.mp4", 104857600)]),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2016)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon BluRay (2016)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2021)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon BluRay (2021)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2028)",
|
||||
"children": [
|
||||
{
|
||||
"name": "Pokemon.2028.mkv",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "Pokemon.2028.hdr.mkv.!qB",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon (2028)",
|
||||
[
|
||||
("Pokemon.2028.mkv", 104857600),
|
||||
("Pokemon.2028.hdr.mkv.!qB", 104857600),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon.2029.mp4",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2030)",
|
||||
"children": [
|
||||
{
|
||||
"name": "S",
|
||||
"size": 104857600,
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
("Pokemon.2029.mp4", 104857600),
|
||||
("Pokemon.2039.mp4", 104857600),
|
||||
("Pokemon (2030)", [("S", 104857600)]),
|
||||
],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1117,4 +1117,19 @@ meta_cases = [{
|
||||
"audio_codec": "",
|
||||
"tmdbid": 19995
|
||||
}
|
||||
}, {
|
||||
"path": "/movies/DouBan_IMDB.TOP250.Movies.Mixed.Collection.20240501.FRDS/为奴十二年.12.Years.a.Slave.2013.BluRay.1080p.x265.10bit.2Audio.MNHD-FRDS/12.Years.a.Slave.2013.BluRay.1080p.x265.10bit.2Audio.MNHD-FRDS.mkv",
|
||||
"target": {
|
||||
"type": "未知",
|
||||
"cn_name": "",
|
||||
"en_name": "12 Years A Slave",
|
||||
"year": "2013",
|
||||
"part": "",
|
||||
"season": "",
|
||||
"episode": "",
|
||||
"restype": "BluRay",
|
||||
"pix": "1080p",
|
||||
"video_codec": "x265 10bit",
|
||||
"audio_codec": "2Audio"
|
||||
}
|
||||
}]
|
||||
|
||||
10
tests/run.py
10
tests/run.py
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
|
||||
from tests.test_bluray import BluRayTest
|
||||
from tests.test_metainfo import MetaInfoTest
|
||||
from tests.test_object import ObjectUtilsTest
|
||||
|
||||
@@ -12,6 +13,15 @@ if __name__ == '__main__':
|
||||
suite.addTest(MetaInfoTest('test_emby_format_ids'))
|
||||
suite.addTest(ObjectUtilsTest('test_check_method'))
|
||||
|
||||
# 测试自定义识别词功能
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_with_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_without_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_with_empty_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_custom_words_apply_words_recording'))
|
||||
|
||||
# 测试蓝光目录识别
|
||||
suite.addTest(BluRayTest())
|
||||
|
||||
# 运行测试
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite)
|
||||
|
||||
226
tests/test_bluray.py
Normal file
226
tests/test_bluray.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import Event
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
from tests.cases.files import bluray_files
|
||||
|
||||
|
||||
class BluRayTest(TestCase):
|
||||
def __init__(self, methodName="test"):
|
||||
super().__init__(methodName)
|
||||
self.__history = []
|
||||
self.__root = schemas.FileItem(
|
||||
path="/", name="", type="dir", extension="", size=0
|
||||
)
|
||||
self.__all = {self.__root.path: self.__root}
|
||||
|
||||
def __build_child(parent: schemas.FileItem, files: list[tuple[str, list | int]]):
|
||||
parent.children = []
|
||||
for name, children in files:
|
||||
sep = "" if parent.path.endswith("/") else "/"
|
||||
file_item = schemas.FileItem(
|
||||
path=f"{parent.path}{sep}{name}",
|
||||
name=name,
|
||||
extension=Path(name).suffix[1:],
|
||||
basename=Path(name).stem,
|
||||
type="file" if isinstance(children, int) else "dir",
|
||||
size=children if isinstance(children, int) else 0,
|
||||
)
|
||||
parent.children.append(file_item)
|
||||
self.__all[file_item.path] = file_item
|
||||
if isinstance(children, list):
|
||||
__build_child(file_item, children)
|
||||
|
||||
__build_child(self.__root, bluray_files)
|
||||
|
||||
def _test_do_transfer(self):
|
||||
def __test_do_transfer(path: str):
|
||||
self.__history.clear()
|
||||
TransferChain().do_transfer(
|
||||
force=False,
|
||||
background=False,
|
||||
fileitem=StorageChain().get_file_item(None, Path(path)),
|
||||
)
|
||||
return self.__history
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
"/FOLDER/Digimon/Digimon BluRay (2099)",
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)/BDMV"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)/BDMV/STREAM"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer(
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)/BDMV/STREAM/00001.m2ts"
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon (2199)"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Pokemon.2029.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Pokemon.2029.mp4"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
"/FOLDER/Digimon/Digimon BluRay (2099)",
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
"/FOLDER/Pokemon BluRay (2016)",
|
||||
"/FOLDER/Pokemon BluRay (2021)",
|
||||
"/FOLDER/Pokemon (2028)/Pokemon.2028.mkv",
|
||||
"/FOLDER/Pokemon.2029.mp4",
|
||||
"/FOLDER/Pokemon.2039.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER"),
|
||||
)
|
||||
|
||||
def _test_scrape_metadata(self, mock_metadata_nfo):
|
||||
def __test_scrape_metadata(path: str, excepted_nfo_count: int = 1):
|
||||
"""
|
||||
分别测试手动和自动刮削
|
||||
"""
|
||||
fileitem = StorageChain().get_file_item(None, Path(path))
|
||||
meta = MetaInfoPath(Path(fileitem.path))
|
||||
mediainfo = MediaInfo(tmdb_info={"id": 1, "title": "Test"})
|
||||
|
||||
# 测试手动刮削
|
||||
logger.debug(f"测试手动刮削 {path}")
|
||||
mock_metadata_nfo.call_count = 0
|
||||
MediaChain().scrape_metadata(
|
||||
fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True
|
||||
)
|
||||
# 确保调用了指定次数的metadata_nfo
|
||||
self.assertEqual(mock_metadata_nfo.call_count, excepted_nfo_count)
|
||||
|
||||
# 测试自动刮削
|
||||
logger.debug(f"测试自动刮削 {path}")
|
||||
mock_metadata_nfo.call_count = 0
|
||||
MediaChain().scrape_metadata_event(
|
||||
Event(
|
||||
event_type=EventType.MetadataScrape,
|
||||
event_data={
|
||||
"meta": meta,
|
||||
"mediainfo": mediainfo,
|
||||
"fileitem": fileitem,
|
||||
"file_list": [fileitem.path],
|
||||
"overwrite": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
# 调用了指定次数的metadata_nfo
|
||||
self.assertEqual(mock_metadata_nfo.call_count, excepted_nfo_count)
|
||||
|
||||
# 刮削原盘目录
|
||||
__test_scrape_metadata("/FOLDER/Digimon/Digimon BluRay (2099)")
|
||||
# 刮削电影文件
|
||||
__test_scrape_metadata("/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4")
|
||||
# 刮削电影目录
|
||||
__test_scrape_metadata("/FOLDER", excepted_nfo_count=2)
|
||||
|
||||
@patch("app.chain.ChainBase.metadata_img", return_value=None) # 避免获取图片
|
||||
@patch("app.chain.ChainBase.__init__", return_value=None) # 避免不必要的模块初始化
|
||||
@patch("app.db.transferhistory_oper.TransferHistoryOper.get_by_src")
|
||||
@patch("app.chain.storage.StorageChain.list_files")
|
||||
@patch("app.chain.storage.StorageChain.get_parent_item")
|
||||
@patch("app.chain.storage.StorageChain.get_file_item")
|
||||
def test(
|
||||
self,
|
||||
mock_get_file_item,
|
||||
mock_get_parent_item,
|
||||
mock_list_files,
|
||||
mock_get_by_src,
|
||||
*_,
|
||||
):
|
||||
def get_file_item(storage: str, path: Path):
|
||||
path_posix = path.as_posix()
|
||||
return self.__all.get(path_posix)
|
||||
|
||||
def get_parent_item(fileitem: schemas.FileItem):
|
||||
return get_file_item(None, Path(fileitem.path).parent)
|
||||
|
||||
def list_files(fileitem: schemas.FileItem, recursion: bool = False):
|
||||
if fileitem.type != "dir":
|
||||
return None
|
||||
if recursion:
|
||||
result = []
|
||||
file_path = f"{fileitem.path}/"
|
||||
for path, item in self.__all.items():
|
||||
if path.startswith(file_path):
|
||||
result.append(item)
|
||||
return result
|
||||
else:
|
||||
return fileitem.children
|
||||
|
||||
def get_by_src(src: str, storage: Optional[str] = None):
|
||||
self.__history.append(src)
|
||||
result = TransferHistory()
|
||||
result.status = True
|
||||
return result
|
||||
|
||||
mock_get_file_item.side_effect = get_file_item
|
||||
mock_get_parent_item.side_effect = get_parent_item
|
||||
mock_list_files.side_effect = list_files
|
||||
mock_get_by_src.side_effect = get_by_src
|
||||
|
||||
self._test_do_transfer()
|
||||
|
||||
with patch(
|
||||
"app.chain.media.MediaChain.metadata_nfo", return_value=None
|
||||
) as mock:
|
||||
self._test_scrape_metadata(mock_metadata_nfo=mock)
|
||||
@@ -61,3 +61,38 @@ class MetaInfoTest(TestCase):
|
||||
meta = MetaInfoPath(Path(path_str))
|
||||
self.assertEqual(meta.tmdbid, expected_tmdbid,
|
||||
f"路径 {path_str} 期望的tmdbid为 {expected_tmdbid},实际识别为 {meta.tmdbid}")
|
||||
|
||||
def test_metainfopath_with_custom_words(self):
|
||||
"""测试 MetaInfoPath 使用自定义识别词"""
|
||||
# 测试替换词:将"测试替换"替换为空
|
||||
custom_words = ["测试替换 => "]
|
||||
path = Path("/movies/电影测试替换名称 (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path, custom_words=custom_words)
|
||||
# 验证替换生效:cn_name 不应包含"测试替换"
|
||||
if meta.cn_name:
|
||||
self.assertNotIn("测试替换", meta.cn_name)
|
||||
|
||||
def test_metainfopath_without_custom_words(self):
|
||||
"""测试 MetaInfoPath 不传入自定义识别词"""
|
||||
path = Path("/movies/Normal Movie (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path)
|
||||
# 验证正常识别,不报错
|
||||
self.assertIsNotNone(meta)
|
||||
|
||||
def test_metainfopath_with_empty_custom_words(self):
|
||||
"""测试 MetaInfoPath 传入空的自定义识别词"""
|
||||
path = Path("/movies/Test Movie (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path, custom_words=[])
|
||||
# 验证不报错,正常识别
|
||||
self.assertIsNotNone(meta)
|
||||
|
||||
def test_custom_words_apply_words_recording(self):
|
||||
"""测试 apply_words 记录功能"""
|
||||
custom_words = ["替换词 => 新词"]
|
||||
title = "电影替换词.2024.mkv"
|
||||
meta = MetaInfo(title=title, custom_words=custom_words)
|
||||
# 验证 apply_words 属性存在
|
||||
self.assertTrue(hasattr(meta, 'apply_words'))
|
||||
# 如果替换词被应用,应该记录在 apply_words 中
|
||||
if meta.apply_words:
|
||||
self.assertIn("替换词 => 新词", meta.apply_words)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.9.1'
|
||||
FRONTEND_VERSION = 'v2.9.1'
|
||||
APP_VERSION = 'v2.9.4'
|
||||
FRONTEND_VERSION = 'v2.9.4'
|
||||
|
||||
Reference in New Issue
Block a user