mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 08:22:40 +08:00
Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58592e961f | ||
|
|
9a99b9ce82 | ||
|
|
8c6dca1751 | ||
|
|
cf488d5f5f | ||
|
|
515584d34c | ||
|
|
fb2becc7f2 | ||
|
|
0f8ceb0fac | ||
|
|
a70bf18770 | ||
|
|
2de83c44ab | ||
|
|
7b99f09810 | ||
|
|
6b4ba8bfad | ||
|
|
0c6cfc5020 | ||
|
|
abd9733e7f | ||
|
|
98c3ae5e76 | ||
|
|
bb5a657469 | ||
|
|
7797532350 | ||
|
|
c3a5106adc | ||
|
|
c5fd935dd0 | ||
|
|
ec375a19ae | ||
|
|
51e940617c | ||
|
|
58ec8bd437 | ||
|
|
a096395086 | ||
|
|
4bd08bd915 | ||
|
|
2c849cfa7a | ||
|
|
501d530d1d | ||
|
|
91fc4327f4 | ||
|
|
8d56c67079 | ||
|
|
e52d43458e | ||
|
|
9b125bf9b0 | ||
|
|
0716c65269 | ||
|
|
ba3ce4f1b5 | ||
|
|
07f72b0cdc | ||
|
|
bda19df87f | ||
|
|
5d82fae2b0 | ||
|
|
0813b87221 | ||
|
|
961ecfc720 | ||
|
|
81f30ef25a | ||
|
|
140b0d3df2 | ||
|
|
b3d69d7de4 | ||
|
|
8e65564fb8 | ||
|
|
06ce9bd4de | ||
|
|
274fc2d74f | ||
|
|
2f1a448afe | ||
|
|
99cab7c337 | ||
|
|
81f7548579 | ||
|
|
6ebd50bebc | ||
|
|
378ba51f4d | ||
|
|
63a890e85d | ||
|
|
bf4f9921e2 | ||
|
|
167ae65695 | ||
|
|
2affa7c9b8 | ||
|
|
785540e178 | ||
|
|
bcad4c0bc6 | ||
|
|
5af217fbf5 | ||
|
|
128aa2ef23 | ||
|
|
fce1186dd1 | ||
|
|
9a7b11f804 | ||
|
|
b068a06fa8 | ||
|
|
931a42e981 | ||
|
|
e0a20a6697 | ||
|
|
1ef4374899 | ||
|
|
3b7212740b | ||
|
|
4b80b8dc1f | ||
|
|
b7f24827e6 | ||
|
|
1c08a22881 | ||
|
|
8bd848519d | ||
|
|
e19f2aa76d | ||
|
|
4a99e2896f | ||
|
|
de3c83b0aa | ||
|
|
36bdb831be | ||
|
|
1809690915 | ||
|
|
e51b679380 | ||
|
|
10c26de7cb | ||
|
|
ca5ec8af0f | ||
|
|
d1d7b8ce55 | ||
|
|
77f8983307 | ||
|
|
ba415acd37 | ||
|
|
bcf13099ac | ||
|
|
eb2b34d71c |
@@ -7,7 +7,7 @@ from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
@@ -56,9 +56,6 @@ class MoviePilotAgent:
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
@@ -127,7 +124,8 @@ class MoviePilotAgent:
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -137,34 +135,36 @@ class MoviePilotAgent:
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
@@ -306,8 +306,6 @@ class MoviePilotAgent:
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
|
||||
@@ -45,17 +45,27 @@ class ConversationMemoryManager:
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
@@ -180,15 +190,13 @@ class ConversationMemoryManager:
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
return memory.messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
@@ -218,7 +226,7 @@ class ConversationMemoryManager:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
@@ -229,14 +237,14 @@ class ConversationMemoryManager:
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -35,24 +37,53 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
# 发送和记忆工具调用前的信息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
# 发送消息
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__,
|
||||
"tool_name": self.__class__.__name__,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -84,6 +115,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
|
||||
@@ -51,7 +51,7 @@ class MoviePilotToolFactory:
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
@@ -102,6 +102,7 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -124,6 +125,7 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -25,7 +25,7 @@ class AddDownloadInput(BaseModel):
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,6 +9,8 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
@@ -27,6 +29,27 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
def _get_all_torrents(self, download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
) or []
|
||||
all_torrents.extend(downloading_torrents)
|
||||
|
||||
# 查询已完成的任务(可转移状态)
|
||||
transfer_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.TRANSFER
|
||||
) or []
|
||||
all_torrents.extend(transfer_torrents)
|
||||
|
||||
return all_torrents
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
@@ -60,7 +83,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
@@ -84,14 +107,25 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
title_lower = title.lower()
|
||||
for torrent in all_torrents:
|
||||
# 检查标题或名称是否匹配
|
||||
if (title.lower() in (torrent.title or "").lower()) or \
|
||||
(title.lower() in (torrent.name or "").lower()):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
|
||||
# 检查标题或名称是否匹配(包括下载历史中的标题)
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (torrent.name or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
if title_lower in history.title.lower():
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
@@ -110,7 +144,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
downloads = download_chain.downloading(name=downloader) or []
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
@@ -119,7 +153,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
@@ -21,7 +23,7 @@ class ToolDefinition:
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
@@ -32,6 +34,7 @@ class MoviePilotToolsManager:
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
@@ -44,7 +47,8 @@ class MoviePilotToolsManager:
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -96,6 +100,73 @@ class MoviePilotToolsManager:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in properties:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
# 从 anyOf 中提取实际类型
|
||||
for type_option in any_of:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = value
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = value
|
||||
elif field_type == "boolean" and isinstance(value, str):
|
||||
# 转换字符串为布尔值
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
else:
|
||||
# 其他类型保持原样
|
||||
normalized[key] = value
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
@@ -116,14 +187,21 @@ class MoviePilotToolsManager:
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**arguments)
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
|
||||
@@ -2,11 +2,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
|
||||
@@ -6,12 +6,13 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -67,6 +68,7 @@ def add(
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
# 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
"""
|
||||
@@ -77,7 +79,11 @@ def add(
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
@@ -87,6 +93,7 @@ def add(
|
||||
media_info=mediainfo,
|
||||
torrent_info=torrentinfo
|
||||
)
|
||||
|
||||
did = DownloadChain().download_single(context=context, username=current_user.name,
|
||||
downloader=downloader, save_path=save_path, source="Manual")
|
||||
if not did:
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.helper.image import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
@@ -29,6 +29,13 @@ def login_access_token(
|
||||
mfa_code=otp_password)
|
||||
|
||||
if not success:
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
|
||||
# 用户等级
|
||||
@@ -50,7 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,43 +2,251 @@
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
# 导入版本号
|
||||
try:
|
||||
from version import APP_VERSION
|
||||
except ImportError:
|
||||
APP_VERSION = "unknown"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局工具管理器实例(单例模式,按用户ID缓存)
|
||||
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例(按用户ID缓存)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
global _tools_managers
|
||||
# 使用用户ID作为缓存键
|
||||
cache_key = f"{user_id}_{session_id}"
|
||||
if cache_key not in _tools_managers:
|
||||
_tools_managers[cache_key] = MoviePilotToolsManager(
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
return _tools_managers[cache_key]
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
}
|
||||
if data is not None:
|
||||
error["error"]["data"] = data
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析请求体失败: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
|
||||
)
|
||||
|
||||
# 验证 JSON-RPC 格式
|
||||
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
|
||||
)
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
# 如果有 id,则为请求;没有 id 则为通知
|
||||
is_notification = request_id is None
|
||||
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
elif method == "ping":
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
|
||||
|
||||
# 未知方法
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
# 客户端版本在支持列表中,使用客户端版本
|
||||
negotiated_version = protocol_version
|
||||
logger.info(f"使用客户端协议版本: {negotiated_version}")
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
for tool in tools:
|
||||
mcp_tool = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
mcp_tools.append(mcp_tool)
|
||||
|
||||
return {
|
||||
"tools": mcp_tools
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result_text
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"错误: {str(e)}"
|
||||
}
|
||||
],
|
||||
"isError": True
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
@@ -72,7 +280,7 @@ async def list_tools(
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
|
||||
463
app/api/endpoints/mfa.py
Normal file
463
app/api/endpoints/mfa.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db import get_async_db
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.passkey import PassKeyHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey)
|
||||
"""
|
||||
user: User = await User.async_get_by_name(db, username)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
|
||||
# 检查是否启用了OTP
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
|
||||
|
||||
# ==================== OTP 相关接口 ====================
|
||||
|
||||
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""生成 OTP 密钥及对应的 URI"""
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
if has_passkey:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
# ==================== PassKey 相关接口 ====================
|
||||
|
||||
class PassKeyRegistrationStart(schemas.BaseModel):
|
||||
"""PassKey注册开始请求"""
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
"""PassKey注册完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
"""PassKey认证完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
|
||||
|
||||
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_start(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:必须先启用 OTP
|
||||
if not current_user.is_otp:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
] if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
user_id=current_user.id,
|
||||
username=current_user.name,
|
||||
display_name=current_user.settings.get('nickname') if current_user.settings else None,
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey注册选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成注册选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_finish(
|
||||
passkey_req: PassKeyRegistrationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""完成注册 PassKey - 验证并保存凭证"""
|
||||
try:
|
||||
# 验证注册响应
|
||||
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge
|
||||
)
|
||||
|
||||
# 提取transports
|
||||
transports = None
|
||||
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
|
||||
transports = ','.join(passkey_req.credential['response']['transports'])
|
||||
|
||||
# 保存到数据库
|
||||
passkey = PassKey(
|
||||
user_id=current_user.id,
|
||||
credential_id=credential_id,
|
||||
public_key=public_key,
|
||||
sign_count=sign_count,
|
||||
name=passkey_req.name or "通行密钥",
|
||||
aaguid=aaguid,
|
||||
transports=transports
|
||||
)
|
||||
passkey.create()
|
||||
|
||||
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥注册成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"注册失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
|
||||
def passkey_authenticate_start(
|
||||
passkey_req: PassKeyAuthenticationStart = Body(...)
|
||||
) -> Any:
|
||||
"""开始 PassKey 认证 - 生成认证选项"""
|
||||
try:
|
||||
existing_credentials = None
|
||||
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
if not user:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="用户不存在"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="该用户未注册通行密钥"
|
||||
)
|
||||
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成认证选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
|
||||
def passkey_authenticate_finish(
|
||||
passkey_req: PassKeyAuthenticationFinish
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise HTTPException(status_code=400, detail="无效的凭证")
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
|
||||
# 获取用户
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已禁用")
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
# 生成token
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
def passkey_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
key_list = [
|
||||
{
|
||||
'id': pk.id,
|
||||
'name': pk.name,
|
||||
'created_at': pk.created_at.isoformat() if pk.created_at else None,
|
||||
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
|
||||
'aaguid': pk.aaguid,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=key_list
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取PassKey列表失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"获取列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或无权删除"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"删除PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"删除失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
|
||||
def passkey_verify_mfa(
|
||||
passkey_req: PassKeyAuthenticationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="无效的凭证"
|
||||
)
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="二次验证成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
)
|
||||
@@ -1,15 +1,12 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Annotated
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
@@ -19,7 +16,6 @@ from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -29,12 +25,14 @@ from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
@@ -44,14 +42,13 @@ from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from version import APP_VERSION
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: bool = False,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
@@ -70,77 +67,24 @@ async def fetch_image(
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
if not cache_path.suffix:
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 缓存对像,缓存过期时间为全局图片缓存天数
|
||||
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
if use_cache:
|
||||
content = await cache_backend.get(cache_path.as_posix(), region="images")
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
content = await ImageHelper().async_fetch_image(
|
||||
url=url,
|
||||
proxy=proxy,
|
||||
use_cache=use_cache,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format for URL {url}: {e}")
|
||||
return None
|
||||
|
||||
# 获取请求响应头
|
||||
response_headers = response.headers
|
||||
cache_control_header = response_headers.get("Cache-Control", "")
|
||||
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
|
||||
|
||||
# 保存缓存
|
||||
if use_cache:
|
||||
await cache_backend.set(cache_path.as_posix(), content, region="images")
|
||||
logger.debug(f"Image cached at {cache_path.as_posix()}")
|
||||
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
if if_none_match == etag:
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(status_code=304, headers=headers)
|
||||
|
||||
# 响应
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
@@ -178,8 +122,7 @@ async def cache_img(
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
proxy = "doubanio.com" not in url
|
||||
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
|
||||
|
||||
@@ -191,11 +134,15 @@ def get_global_setting(token: str):
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
# 白名单模式,仅包含前端业务逻辑必需的字段
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE"
|
||||
}
|
||||
)
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
@@ -248,13 +195,11 @@ async def set_env_setting(env: dict,
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
for key in success_updates.keys():
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=getattr(settings, key, None),
|
||||
change_type="update"
|
||||
))
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
|
||||
@@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
|
||||
|
||||
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
|
||||
def get_config(key: str,
|
||||
current_user: User = Depends(get_current_active_user)):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -128,9 +128,12 @@ async def get_cookie(
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
async def post_cookie(
|
||||
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
|
||||
request: schemas.CookiePassword):
|
||||
request: Optional[schemas.CookiePassword] = Body(None)):
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
if request is not None:
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
from app.schemas import ExistMediaInfo, FileURI, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
ResourceDownloadEventData
|
||||
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
|
||||
ChainEventType
|
||||
@@ -162,7 +162,7 @@ class DownloadChain(ChainBase):
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、Subscribe、Manual等)
|
||||
:param downloader: 下载器
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
@@ -232,13 +232,14 @@ class DownloadChain(ChainBase):
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
|
||||
storage = 'local'
|
||||
# 下载目录
|
||||
if save_path:
|
||||
# 下载目录使用自定义的
|
||||
download_dir = Path(save_path)
|
||||
else:
|
||||
# 根据媒体信息查询下载目录配置
|
||||
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
|
||||
dir_info = DirectoryHelper().get_dir(_media, include_unsorted=True)
|
||||
storage = dir_info.storage if dir_info else storage
|
||||
# 拼装子目录
|
||||
if dir_info:
|
||||
# 一级目录
|
||||
@@ -259,6 +260,8 @@ class DownloadChain(ChainBase):
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
# 添加下载
|
||||
result: Optional[tuple] = self.download(content=torrent_content,
|
||||
@@ -400,7 +403,7 @@ class DownloadChain(ChainBase):
|
||||
根据缺失数据,自动种子列表中组合择优下载
|
||||
:param contexts: 资源上下文列表
|
||||
:param no_exists: 缺失的剧集信息
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、订阅、手工下载等)
|
||||
:param userid: 用户ID
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached, FileCache
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -103,40 +99,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
# 生成缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
|
||||
cache_backend = FileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
# 本地存在缓存图片,则直接跳过
|
||||
if cache_backend.get(cache_path.as_posix(), region="images"):
|
||||
logger.debug(f"Cache hit: Image already exists at {cache_path}")
|
||||
return
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if not referer else None
|
||||
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
|
||||
if not response:
|
||||
logger.debug(f"Empty response for URL: {url}")
|
||||
return
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
Image.open(io.BytesIO(response.content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
return
|
||||
|
||||
# 保存缓存
|
||||
cache_backend.set(cache_path.as_posix(), response.content, region="images")
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
ImageHelper().fetch_image(url=url)
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
|
||||
@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -169,7 +169,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
@@ -949,7 +949,7 @@ class SubscribeChain(ChainBase):
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -52,7 +52,10 @@ class UserChain(ChainBase):
|
||||
success, user_or_message = self.password_authenticate(credentials=credentials)
|
||||
if success:
|
||||
# 如果用户启用了二次验证码,则进一步验证
|
||||
if not self._verify_mfa(user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
logger.info(f"用户 {username} 通过密码认证成功")
|
||||
return True, user_or_message
|
||||
@@ -63,7 +66,10 @@ class UserChain(ChainBase):
|
||||
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
|
||||
if aux_success:
|
||||
# 辅助认证成功后再验证二次验证码
|
||||
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
return True, aux_user_or_message
|
||||
else:
|
||||
@@ -159,22 +165,46 @@ class UserChain(ChainBase):
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
|
||||
@staticmethod
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]:
|
||||
"""
|
||||
验证 MFA(二次验证码)
|
||||
检查用户是否启用了 OTP 或 PassKey,如果启用了任何一种,都需要提供验证
|
||||
|
||||
:param user: 用户对象
|
||||
:param mfa_code: 二次验证码
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:param mfa_code: 二次验证码(如果提供了则验证OTP)
|
||||
:return:
|
||||
- 如果验证成功返回 True
|
||||
- 如果需要MFA但未提供,返回 "MFA_REQUIRED"
|
||||
- 如果MFA验证失败,返回 False
|
||||
"""
|
||||
if not user.is_otp:
|
||||
# 检查用户是否有PassKey
|
||||
from app.db.models.passkey import PassKey
|
||||
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
|
||||
|
||||
# 如果用户既没有启用OTP也没有PassKey,直接通过
|
||||
if not user.is_otp and not has_passkey:
|
||||
return True
|
||||
|
||||
# 如果用户启用了OTP或PassKey,但没有提供验证码,需要进行二次验证
|
||||
if not mfa_code:
|
||||
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
|
||||
return False
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
# 如果提供了验证码,且用户启用了 OTP,则验证 OTP
|
||||
if user.is_otp:
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
# OTP 验证成功
|
||||
return True
|
||||
|
||||
# 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证
|
||||
if has_passkey:
|
||||
logger.info(
|
||||
f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
|
||||
)
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
return True
|
||||
|
||||
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
|
||||
|
||||
@@ -393,6 +393,8 @@ class ConfigModel(BaseModel):
|
||||
])
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
@@ -407,6 +409,8 @@ class ConfigModel(BaseModel):
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
# Playwright浏览器类型,chromium/firefox
|
||||
PLAYWRIGHT_BROWSER_TYPE: str = "chromium"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
@@ -430,9 +434,9 @@ class ConfigModel(BaseModel):
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 30
|
||||
# 内存记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
@@ -95,18 +95,20 @@ class TorrentInfo:
|
||||
if upload_volume_factor is None or download_volume_factor is None:
|
||||
return "未知"
|
||||
free_strs = {
|
||||
"1.0 1.0": "普通",
|
||||
"1.0 0.0": "免费",
|
||||
"2.0 1.0": "2X",
|
||||
"4.0 1.0": "4X",
|
||||
"2.0 0.0": "2X免费",
|
||||
"4.0 0.0": "4X免费",
|
||||
"1.0 0.5": "50%",
|
||||
"2.0 0.5": "2X 50%",
|
||||
"1.0 0.7": "70%",
|
||||
"1.0 0.3": "30%"
|
||||
"1.00 1.00": "普通",
|
||||
"1.00 0.00": "免费",
|
||||
"2.00 1.00": "2X",
|
||||
"4.00 1.00": "4X",
|
||||
"2.00 0.00": "2X免费",
|
||||
"4.00 0.00": "4X免费",
|
||||
"1.00 0.50": "50%",
|
||||
"2.00 0.50": "2X 50%",
|
||||
"1.00 0.70": "70%",
|
||||
"1.00 0.30": "30%",
|
||||
"1.00 0.75": "75%",
|
||||
"1.00 0.25": "25%"
|
||||
}
|
||||
return free_strs.get('%.1f %.1f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
return free_strs.get('%.2f %.2f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
|
||||
@property
|
||||
def volume_factor(self):
|
||||
|
||||
@@ -6,11 +6,11 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -20,7 +20,7 @@ from watchfiles import watch
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
@@ -28,16 +28,16 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
"""
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -250,20 +250,12 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
return self._plugins
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
|
||||
return
|
||||
logger.info("配置变更,重新加载插件文件修改监测...")
|
||||
def on_config_changed(self):
|
||||
self.reload_monitor()
|
||||
|
||||
def get_reload_name(self) -> str:
|
||||
return "插件文件修改监测"
|
||||
|
||||
def reload_monitor(self):
|
||||
"""
|
||||
重新加载插件文件修改监测
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
|
||||
126
app/db/models/passkey.py
Normal file
126
app/db/models/passkey.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, select, ForeignKey
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
|
||||
|
||||
|
||||
class PassKey(Base):
|
||||
"""
|
||||
用户PassKey凭证表
|
||||
"""
|
||||
# ID
|
||||
id = get_id_column()
|
||||
# 用户ID
|
||||
user_id = Column(Integer, ForeignKey('user.id'), nullable=False, index=True)
|
||||
# 凭证ID (credential_id)
|
||||
credential_id = Column(String, nullable=False, unique=True, index=True)
|
||||
# 凭证公钥
|
||||
public_key = Column(Text, nullable=False)
|
||||
# 签名计数器
|
||||
sign_count = Column(Integer, default=0)
|
||||
# 凭证名称(用户自定义)
|
||||
name = Column(String, default="通行密钥")
|
||||
# AAGUID (Authenticator Attestation GUID)
|
||||
aaguid = Column(String, nullable=True)
|
||||
# 创建时间
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
# 最后使用时间
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
# 是否启用
|
||||
is_active = Column(Boolean, default=True)
|
||||
# 传输方式 (usb, nfc, ble, internal)
|
||||
transports = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_user_id(cls, db: Session, user_id: int):
|
||||
"""获取用户的所有PassKey"""
|
||||
return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
|
||||
"""异步获取用户的所有PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_credential_id(cls, db: Session, credential_id: str):
|
||||
"""根据凭证ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
|
||||
"""异步根据凭证ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_id(cls, db: Session, passkey_id: int):
|
||||
"""根据ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.id == passkey_id).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_id(cls, db: AsyncSession, passkey_id: int):
|
||||
"""异步根据ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.id == passkey_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
|
||||
"""删除指定用户的PassKey"""
|
||||
passkey = db.query(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
).first()
|
||||
if passkey:
|
||||
passkey.delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
|
||||
"""异步删除指定用户的PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
)
|
||||
)
|
||||
passkey = result.scalars().first()
|
||||
if passkey:
|
||||
await passkey.async_delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@db_update
|
||||
def update_last_used(self, db: Session, sign_count: int):
|
||||
"""更新最后使用时间和签名计数"""
|
||||
self.update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_update_last_used(self, db: AsyncSession, sign_count: int):
|
||||
"""异步更新最后使用时间和签名计数"""
|
||||
await self.async_update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
@@ -10,7 +10,7 @@ from app.utils.http import RequestUtils, cookie_parse
|
||||
|
||||
|
||||
class PlaywrightHelper:
|
||||
def __init__(self, browser_type="chromium"):
|
||||
def __init__(self, browser_type=settings.PLAYWRIGHT_BROWSER_TYPE):
|
||||
self.browser_type = browser_type
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
@@ -9,7 +9,7 @@ from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?}}", re.DOTALL)
|
||||
|
||||
|
||||
class DirectoryHelper:
|
||||
@@ -51,7 +51,7 @@ class DirectoryHelper:
|
||||
"""
|
||||
return [d for d in self.get_library_dirs() if d.library_storage == "local"]
|
||||
|
||||
def get_dir(self, media: MediaInfo, include_unsorted: Optional[bool] = False,
|
||||
def get_dir(self, media: Optional[MediaInfo], include_unsorted: Optional[bool] = False,
|
||||
storage: Optional[str] = None, src_path: Path = None,
|
||||
target_storage: Optional[str] = None, dest_path: Path = None
|
||||
) -> Optional[schemas.TransferDirectoryConf]:
|
||||
@@ -64,11 +64,8 @@ class DirectoryHelper:
|
||||
:param src_path: 源目录,有值时直接匹配
|
||||
:param dest_path: 目标目录,有值时直接匹配
|
||||
"""
|
||||
# 处理类型
|
||||
if not media:
|
||||
return None
|
||||
# 电影/电视剧
|
||||
media_type = media.type.value
|
||||
media_type = media.type.value if media else None
|
||||
dirs = self.get_dirs()
|
||||
|
||||
# 如果存在源目录,并源目录为任一下载目录的子目录时,则进行源目录匹配,否则,允许源目录按同盘优先的逻辑匹配
|
||||
@@ -93,7 +90,7 @@ class DirectoryHelper:
|
||||
if dest_path and dest_path != Path(d.library_path):
|
||||
continue
|
||||
# 目录类型为全部的,符合条件
|
||||
if not d.media_type:
|
||||
if not media_type or not d.media_type:
|
||||
matched_dirs.append(d)
|
||||
continue
|
||||
# 目录类型相等,目录类别为全部,符合条件
|
||||
@@ -109,11 +106,27 @@ class DirectoryHelper:
|
||||
# 优先源目录同盘
|
||||
for matched_dir in matched_dirs:
|
||||
matched_path = Path(matched_dir.download_path)
|
||||
if SystemUtils.is_same_disk(matched_path, src_path):
|
||||
if self._is_same_source((src_path, storage or "local"), (matched_path, matched_dir.library_storage)):
|
||||
return matched_dir
|
||||
return matched_dirs[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_same_source(src: Tuple[Path, str], tar: Tuple[Path, str]) -> bool:
|
||||
"""
|
||||
判断源目录和目标目录是否在同一存储盘
|
||||
|
||||
:param src: 源目录路径和存储类型
|
||||
:param tar: 目标目录路径和存储类型
|
||||
:return: 是否在同一存储盘
|
||||
"""
|
||||
src_path, src_storage = src
|
||||
tar_path, tar_storage = tar
|
||||
if "local" == tar_storage == src_storage:
|
||||
return SystemUtils.is_same_disk(src_path, tar_path)
|
||||
# 网络存储,直接比较类型
|
||||
return src_storage == tar_storage
|
||||
|
||||
@staticmethod
|
||||
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
|
||||
"""
|
||||
|
||||
@@ -14,10 +14,8 @@ from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 定义一个全局线程池执行器
|
||||
@@ -69,25 +67,23 @@ def enable_doh(enable: bool):
|
||||
socket.getaddrinfo = _orig_getaddrinfo
|
||||
|
||||
|
||||
class DohHelper(metaclass=Singleton):
|
||||
class DohHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
DoH帮助类,用于处理DNS over HTTPS解析。
|
||||
"""
|
||||
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
|
||||
|
||||
def __init__(self):
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ["DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"]:
|
||||
return
|
||||
def on_config_changed(self):
|
||||
with _doh_lock:
|
||||
# DOH配置有变动的情况下,清空缓存
|
||||
_doh_cache.clear()
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
def get_reload_name(self):
|
||||
return 'DoH'
|
||||
|
||||
def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, FileCache, AsyncFileCache
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.ip import IpUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -161,3 +168,121 @@ class WallpaperHelper(metaclass=Singleton):
|
||||
return wallpaper_list
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class ImageHelper(metaclass=Singleton):
|
||||
|
||||
def __init__(self):
|
||||
_base_path = settings.CACHE_PATH
|
||||
_ttl = settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600
|
||||
self.file_cache = FileCache(base=_base_path, ttl=_ttl)
|
||||
self.async_file_cache = AsyncFileCache(base=_base_path, ttl=_ttl)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_cache_path(url: str) -> str:
|
||||
"""缓存路径"""
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path(sanitized_path)
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
return cache_path.as_posix()
|
||||
|
||||
@staticmethod
|
||||
def _validate_image(content: bytes) -> bool:
|
||||
"""验证图片"""
|
||||
if not content:
|
||||
return False
|
||||
try:
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_request_params(url: str, proxy: Optional[bool], cookies: Optional[str | dict]) -> dict:
|
||||
"""获取参数"""
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
if proxy is None:
|
||||
proxies = settings.PROXY if not (referer or IpUtils.is_internal(url)) else None
|
||||
else:
|
||||
proxies = settings.PROXY if proxy else None
|
||||
return {
|
||||
"ua": settings.NORMAL_USER_AGENT,
|
||||
"proxies": proxies,
|
||||
"referer": referer,
|
||||
"cookies": cookies,
|
||||
"accept_type": "image/avif,image/webp,image/apng,*/*",
|
||||
}
|
||||
|
||||
def fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(同步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = self.file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = RequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
self.file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
|
||||
async def async_fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(异步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = await self.async_file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = await AsyncRequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
await self.async_file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
352
app/helper/passkey.py
Normal file
352
app/helper/passkey.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
PassKey WebAuthn 辅助工具类
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import binascii
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from webauthn import (
|
||||
generate_registration_options,
|
||||
verify_registration_response,
|
||||
generate_authentication_options,
|
||||
verify_authentication_response,
|
||||
options_to_json
|
||||
)
|
||||
from webauthn.helpers import (
|
||||
parse_registration_credential_json,
|
||||
parse_authentication_credential_json
|
||||
)
|
||||
from webauthn.helpers.structs import (
|
||||
PublicKeyCredentialDescriptor,
|
||||
AuthenticatorTransport,
|
||||
UserVerificationRequirement,
|
||||
AuthenticatorAttachment,
|
||||
ResidentKeyRequirement,
|
||||
AuthenticatorSelectionCriteria
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PassKeyHelper:
|
||||
"""
|
||||
PassKey WebAuthn 辅助类
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_rp_id() -> str:
|
||||
"""
|
||||
获取 Relying Party ID
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
app_domain = settings.APP_DOMAIN.strip()
|
||||
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
|
||||
if not app_domain.startswith(('http://', 'https://')):
|
||||
app_domain = f'https://{app_domain}'
|
||||
parsed = urlparse(app_domain)
|
||||
host = parsed.hostname
|
||||
if host:
|
||||
return host
|
||||
# 从 APP_DOMAIN 中提取域名
|
||||
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
|
||||
# 移除端口号
|
||||
if ':' in host:
|
||||
host = host.split(':')[0]
|
||||
return host
|
||||
# 只有在未配置 APP_DOMAIN 时,才默认为 localhost
|
||||
return 'localhost'
|
||||
|
||||
@staticmethod
|
||||
def get_rp_name() -> str:
|
||||
"""
|
||||
获取 Relying Party 名称
|
||||
"""
|
||||
return "MoviePilot"
|
||||
|
||||
@staticmethod
|
||||
def get_origin() -> str:
|
||||
"""
|
||||
获取源地址
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
return settings.APP_DOMAIN.rstrip('/')
|
||||
# 如果未配置APP_DOMAIN,使用默认的localhost地址
|
||||
return f'http://localhost:{settings.NGINX_PORT}'
|
||||
|
||||
@staticmethod
|
||||
def standardize_credential_id(credential_id: str) -> str:
|
||||
"""
|
||||
标准化凭证ID(Base64 URL Safe)
|
||||
"""
|
||||
try:
|
||||
# Base64解码并重新编码以标准化格式
|
||||
decoded = base64.urlsafe_b64decode(credential_id + '==')
|
||||
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
|
||||
except (binascii.Error, TypeError, ValueError) as e:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@staticmethod
|
||||
def generate_registration_options(
|
||||
user_id: int,
|
||||
username: str,
|
||||
display_name: Optional[str] = None,
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成注册选项
|
||||
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param display_name: 显示名称
|
||||
:param existing_credentials: 已存在的凭证列表
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 用户信息
|
||||
user_id_bytes = str(user_id).encode('utf-8')
|
||||
|
||||
# 排除已有的凭证
|
||||
exclude_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
exclude_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
|
||||
# 生成注册选项
|
||||
options = generate_registration_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
rp_name=PassKeyHelper.get_rp_name(),
|
||||
user_id=user_id_bytes,
|
||||
user_name=username,
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=uv_requirement,
|
||||
),
|
||||
supported_pub_key_algs=[
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
|
||||
]
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge(用于后续验证)
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成注册选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str:
|
||||
"""
|
||||
在 localhost 环境下获取并验证实际 Origin,否则返回默认值
|
||||
"""
|
||||
if not settings.APP_DOMAIN and rp_id == 'localhost':
|
||||
try:
|
||||
# 解析 clientDataJSON 获取实际的 origin
|
||||
client_data_json = json.loads(
|
||||
base64.urlsafe_b64decode(
|
||||
credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '=='
|
||||
).decode('utf-8')
|
||||
)
|
||||
actual_origin = client_data_json.get('origin', '')
|
||||
hostname = urlparse(actual_origin).hostname
|
||||
|
||||
if hostname in ['localhost', '127.0.0.1']:
|
||||
logger.info(f"本地环境,使用动态 origin: {actual_origin}")
|
||||
return actual_origin
|
||||
except Exception as e:
|
||||
logger.warning(f"无法提取动态 origin: {e}")
|
||||
return default_origin
|
||||
|
||||
@staticmethod
|
||||
def verify_registration_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str, int, Optional[str]]:
|
||||
"""
|
||||
验证注册响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (credential_id, public_key, sign_count, aaguid)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
# 解码challenge
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
|
||||
# 构建RegistrationCredential对象
|
||||
registration_credential = parse_registration_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证注册响应
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
# 提取信息
|
||||
credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=')
|
||||
public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=')
|
||||
sign_count = verification.sign_count
|
||||
# aaguid 可能已经是字符串格式,也可能是bytes
|
||||
if verification.aaguid:
|
||||
if isinstance(verification.aaguid, bytes):
|
||||
aaguid = verification.aaguid.hex()
|
||||
else:
|
||||
aaguid = str(verification.aaguid)
|
||||
else:
|
||||
aaguid = None
|
||||
|
||||
return credential_id, public_key, sign_count, aaguid
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证注册响应失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def generate_authentication_options(
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None,
|
||||
user_verification: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成认证选项
|
||||
|
||||
:param existing_credentials: 已存在的凭证列表(用于限制可用凭证)
|
||||
:param user_verification: 用户验证要求,如果不指定则从配置中读取
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 允许的凭证
|
||||
allow_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
allow_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
|
||||
# 用户验证要求
|
||||
if not user_verification:
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
else:
|
||||
uv_requirement = UserVerificationRequirement(user_verification)
|
||||
|
||||
# 生成认证选项
|
||||
options = generate_authentication_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
allow_credentials=allow_credentials if allow_credentials else None,
|
||||
user_verification=uv_requirement
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成认证选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def verify_authentication_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
credential_public_key: str,
|
||||
credential_current_sign_count: int,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
验证认证响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param credential_public_key: 凭证公钥
|
||||
:param credential_current_sign_count: 当前签名计数
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (验证成功, 新的签名计数)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
# 解码
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==')
|
||||
|
||||
# 构建AuthenticationCredential对象
|
||||
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证认证响应
|
||||
verification = verify_authentication_response(
|
||||
credential=authentication_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
credential_public_key=public_key_bytes,
|
||||
credential_current_sign_count=credential_current_sign_count,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
return True, verification.new_sign_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证认证响应失败: {e}")
|
||||
return False, credential_current_sign_count
|
||||
@@ -7,10 +7,8 @@ import redis
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 类型缓存集合,针对非容器简单类型
|
||||
@@ -74,16 +72,17 @@ def deserialize(value: bytes) -> Any:
|
||||
raise ValueError("Unknown serialization format")
|
||||
|
||||
|
||||
class RedisHelper(metaclass=Singleton):
|
||||
class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
- 支持内存限制和淘汰策略设置
|
||||
- 提供键名生成和区域管理功能
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -114,25 +113,17 @@ class RedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis...")
|
||||
def on_config_changed(self):
|
||||
self.close()
|
||||
self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis"
|
||||
|
||||
def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
@@ -310,10 +301,10 @@ class RedisHelper(metaclass=Singleton):
|
||||
logger.debug("Redis connection closed")
|
||||
|
||||
|
||||
class AsyncRedisHelper(metaclass=Singleton):
|
||||
class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
异步Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理异步Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
@@ -321,6 +312,7 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
- 提供键名生成和区域管理功能
|
||||
- 所有操作都是异步的
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -351,25 +343,17 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis async connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
async def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis (async)...")
|
||||
async def on_config_changed(self):
|
||||
await self.close()
|
||||
await self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis (async)"
|
||||
|
||||
async def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
|
||||
@@ -8,35 +8,32 @@ from typing import Tuple
|
||||
import docker
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class SystemHelper:
|
||||
class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
系统工具类,提供系统相关的操作和判断
|
||||
"""
|
||||
CONFIG_WATCH = {
|
||||
"DEBUG",
|
||||
"LOG_LEVEL",
|
||||
"LOG_MAX_FILE_SIZE",
|
||||
"LOG_BACKUP_COUNT",
|
||||
"LOG_FILE_FORMAT",
|
||||
"LOG_CONSOLE_FORMAT",
|
||||
}
|
||||
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新日志设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEBUG', 'LOG_LEVEL', 'LOG_MAX_FILE_SIZE', 'LOG_BACKUP_COUNT',
|
||||
'LOG_FILE_FORMAT', 'LOG_CONSOLE_FORMAT']:
|
||||
return
|
||||
logger.info("配置变更,更新日志设置...")
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "日志设置"
|
||||
|
||||
@staticmethod
|
||||
def can_restart() -> bool:
|
||||
"""
|
||||
|
||||
@@ -6,8 +6,7 @@ from urllib.parse import unquote
|
||||
|
||||
from torrentool.api import Torrent
|
||||
|
||||
from app.core.cache import FileCache
|
||||
from app.core.cache import TTLCache
|
||||
from app.core.cache import TTLCache, FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, TorrentInfo, MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas import Notification, NotificationConf, MediaServerConf, DownloaderConf
|
||||
from app.schemas.types import ModuleType, DownloaderType, MediaServerType, MessageChannel, StorageSchema, \
|
||||
OtherModulesType
|
||||
OtherModulesType, SystemConfigKey
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
|
||||
|
||||
class _ModuleBase(metaclass=ABCMeta):
|
||||
class _ModuleBase(ConfigReloadMixin, metaclass=ABCMeta):
|
||||
"""
|
||||
模块基类,实现对应方法,在有需要时会被自动调用,返回None代表不启用该模块,将继续执行下一模块
|
||||
输入参数与输出参数一致的,或没有输出的,可以被多个模块重复实现
|
||||
"""
|
||||
|
||||
def on_config_changed(self):
|
||||
self.init_module()
|
||||
|
||||
def get_reload_name(self):
|
||||
return self.get_name()
|
||||
|
||||
@abstractmethod
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -177,6 +185,7 @@ class _MessageBase(ServiceBase[TService, NotificationConf]):
|
||||
"""
|
||||
消息基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Notifications.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -224,6 +233,7 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
"""
|
||||
下载器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Downloaders.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -281,12 +291,37 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
重置默认配置名称
|
||||
"""
|
||||
self._default_config_name = None
|
||||
|
||||
def normalize_path(self, path: Path, downloader: Optional[str]) -> str:
|
||||
"""
|
||||
根据下载器配置和路径映射,规范化下载路径
|
||||
|
||||
:param path: 存储路径
|
||||
:param downloader: 下载器名称
|
||||
:return: 规范化后发送给下载器的路径
|
||||
"""
|
||||
dir = path.as_posix()
|
||||
conf = self.get_config(downloader)
|
||||
if conf and conf.path_mapping:
|
||||
for (storage_path, download_path) in conf.path_mapping:
|
||||
storage_path = Path(storage_path.strip()).as_posix()
|
||||
download_path = Path(download_path.strip()).as_posix()
|
||||
if dir.startswith(storage_path):
|
||||
dir = dir.replace(storage_path, download_path, 1)
|
||||
break
|
||||
# 去掉存储协议前缀 if any, 下载器无法识别
|
||||
for s in StorageSchema:
|
||||
prefix = f"{s.value}:"
|
||||
if dir.startswith(prefix):
|
||||
return dir[len(prefix):]
|
||||
return dir
|
||||
|
||||
|
||||
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
|
||||
"""
|
||||
媒体服务器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.MediaServers.value}
|
||||
|
||||
def get_configs(self) -> Dict[str, MediaServerConf]:
|
||||
"""
|
||||
|
||||
214
app/modules/discord/__init__.py
Normal file
214
app/modules/discord/__init__.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import json
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
except Exception as err: # ImportError or other load issues
|
||||
Discord = None
|
||||
logger.error(f"Discord 模块未加载,缺少依赖或初始化错误:{err}")
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
"""
|
||||
if not Discord:
|
||||
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
|
||||
return
|
||||
super().init_service(service_name=Discord.__name__.lower(),
|
||||
service_type=Discord)
|
||||
self._channel = MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Discord"
|
||||
|
||||
@staticmethod
|
||||
def get_type() -> ModuleType:
|
||||
"""
|
||||
获取模块类型
|
||||
"""
|
||||
return ModuleType.Notification
|
||||
|
||||
@staticmethod
|
||||
def get_subtype() -> MessageChannel:
|
||||
"""
|
||||
获取模块子类型
|
||||
"""
|
||||
return MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_priority() -> int:
|
||||
"""
|
||||
获取模块优先级,数字越小优先级越高,只有同一接口下优先级才生效
|
||||
"""
|
||||
return 4
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止模块
|
||||
"""
|
||||
for client in self.get_instances().values():
|
||||
client.stop()
|
||||
|
||||
def test(self) -> Optional[Tuple[bool, str]]:
|
||||
"""
|
||||
测试模块连接性
|
||||
"""
|
||||
if not self.get_instances():
|
||||
return None
|
||||
for name, client in self.get_instances().items():
|
||||
state = client.get_state()
|
||||
if not state:
|
||||
return False, f"Discord {name} Bot 未就绪"
|
||||
return True, ""
|
||||
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
|
||||
"""
|
||||
解析消息内容,返回字典,注意以下约定值:
|
||||
userid: 用户ID
|
||||
username: 用户名
|
||||
text: 内容
|
||||
:param source: 消息来源
|
||||
:param body: 请求体
|
||||
:param form: 表单
|
||||
:param args: 参数
|
||||
:return: 渠道、消息体
|
||||
"""
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
try:
|
||||
msg_json: dict = json.loads(body)
|
||||
except Exception as e:
|
||||
logger.debug(f"解析 Discord 消息失败:{str(e)}")
|
||||
return None
|
||||
|
||||
if not msg_json:
|
||||
return None
|
||||
|
||||
msg_type = msg_json.get("type")
|
||||
userid = msg_json.get("userid")
|
||||
username = msg_json.get("username")
|
||||
|
||||
if msg_type == "interaction":
|
||||
callback_data = msg_json.get("callback_data")
|
||||
message_id = msg_json.get("message_id")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if callback_data and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}")
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=f"CALLBACK:{callback_data}",
|
||||
is_callback=True,
|
||||
callback_data=callback_data,
|
||||
message_id=message_id,
|
||||
chat_id=str(chat_id) if chat_id else None
|
||||
)
|
||||
return None
|
||||
|
||||
if msg_type == "message":
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if text and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}")
|
||||
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
|
||||
userid=userid, username=username, text=text,
|
||||
chat_id=str(chat_id) if chat_id else None)
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
:param message: 消息通知对象
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get('discord_userid')
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
:param message: 消息体
|
||||
:param medias: 媒体信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
|
||||
"""
|
||||
发送种子信息选择列表
|
||||
:param message: 消息体
|
||||
:param torrents: 种子信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_torrents_msg(title=message.title, torrents=torrents,
|
||||
userid=message.userid, buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def delete_message(self, channel: MessageChannel, source: str,
|
||||
message_id: str, chat_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 指定的消息源
|
||||
:param message_id: 消息ID(Slack中为时间戳)
|
||||
:param chat_id: 聊天ID(频道ID)
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.delete_msg(message_id=message_id, chat_id=chat_id)
|
||||
if result:
|
||||
success = True
|
||||
return success
|
||||
519
app/modules/discord/discord.py
Normal file
519
app/modules/discord/discord.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class Discord:
|
||||
"""
|
||||
Discord Bot 通知与交互实现(基于 discord.py 2.6.4)
|
||||
"""
|
||||
|
||||
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.messages = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client: Optional[discord.Client] = discord.Client(intents=intents)
|
||||
self._tree: Optional[app_commands.CommandTree] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
self._register_events()
|
||||
self._start()
|
||||
|
||||
@staticmethod
|
||||
def _to_int(val: Optional[Union[str, int]]) -> Optional[int]:
|
||||
try:
|
||||
return int(val) if val is not None and str(val).strip() else None
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _register_events(self):
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
self._bot_user_id = self._client.user.id if self._client.user else None
|
||||
self._ready_event.set()
|
||||
logger.info(f"Discord Bot 已登录:{self._client.user}")
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: discord.Message):
|
||||
if message.author.bot:
|
||||
return
|
||||
if not self._should_process_message(message):
|
||||
return
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
payload = {
|
||||
"type": "message",
|
||||
"userid": str(message.author.id),
|
||||
"username": username,
|
||||
"user_tag": str(message.author),
|
||||
"text": cleaned_text,
|
||||
"message_id": str(message.id),
|
||||
"chat_id": str(message.channel.id),
|
||||
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@self._client.event
|
||||
async def on_interaction(interaction: discord.Interaction):
|
||||
if interaction.type == discord.InteractionType.component:
|
||||
data = interaction.data or {}
|
||||
callback_data = data.get("custom_id")
|
||||
if not callback_data:
|
||||
return
|
||||
try:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Discord 交互响应失败:{e}")
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
payload = {
|
||||
"type": "interaction",
|
||||
"userid": str(interaction.user.id) if interaction.user else None,
|
||||
"username": username,
|
||||
"user_tag": str(interaction.user) if interaction.user else None,
|
||||
"callback_data": callback_data,
|
||||
"message_id": str(interaction.message.id) if interaction.message else None,
|
||||
"chat_id": str(interaction.channel.id) if interaction.channel else None
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
def _start(self):
|
||||
if self._thread:
|
||||
return
|
||||
|
||||
def runner():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.create_task(self._client.start(self._token))
|
||||
self._loop.run_forever()
|
||||
except Exception as err:
|
||||
logger.error(f"Discord Bot 启动失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.run_until_complete(self._client.close())
|
||||
except Exception as err:
|
||||
logger.debug(f"Discord Bot 关闭失败:{err}")
|
||||
|
||||
self._thread = threading.Thread(target=runner, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
if not self._client or not self._loop or not self._thread:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
|
||||
except Exception as err:
|
||||
logger.error(f"关闭 Discord Bot 失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
except Exception as err:
|
||||
logger.error(f"停止 Discord 事件循环失败:{err}")
|
||||
self._ready_event.clear()
|
||||
|
||||
def get_state(self) -> bool:
|
||||
return self._ready_event.is_set() and self._client is not None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
if not title and not text:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return False
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id),
|
||||
self._loop)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not medias:
|
||||
return False
|
||||
title = title or "媒体列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_media_embeds(medias, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 媒体列表失败:{err}")
|
||||
return False
|
||||
|
||||
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not torrents:
|
||||
return False
|
||||
title = title or "种子列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_torrent_embeds(torrents, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 种子列表失败:{err}")
|
||||
return False
|
||||
|
||||
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._delete_message(message_id=message_id, chat_id=chat_id),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=15)
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
|
||||
userid: Optional[str], link: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
embed = self._build_embed(title=title, text=text, image=image, link=link)
|
||||
view = self._build_view(buttons=buttons, link=link)
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
return True
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
fallback_buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
view = self._build_view(buttons=buttons if buttons else fallback_buttons)
|
||||
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=None, embed=None, view=view, embeds=embeds)
|
||||
|
||||
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
|
||||
embeds=embeds if len(embeds) > 1 else None,
|
||||
view=view)
|
||||
return True
|
||||
|
||||
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
|
||||
content: Optional[str], embed: Optional[discord.Embed],
|
||||
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=str(chat_id))
|
||||
if not channel:
|
||||
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
kwargs: Dict[str, Any] = {"content": content, "view": view}
|
||||
if embeds:
|
||||
if len(embeds) == 1:
|
||||
kwargs["embed"] = embeds[0]
|
||||
else:
|
||||
kwargs["embeds"] = embeds
|
||||
elif embed:
|
||||
kwargs["embed"] = embed
|
||||
await message.edit(**kwargs)
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"编辑 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=chat_id)
|
||||
if not channel:
|
||||
logger.error("删除 Discord 消息时未找到频道")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
await message.delete()
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_embed(title: str, text: Optional[str], image: Optional[str],
|
||||
link: Optional[str]) -> discord.Embed:
|
||||
description = ""
|
||||
fields: List[Dict[str, str]] = []
|
||||
if text:
|
||||
for line in text.splitlines():
|
||||
if ":" in line:
|
||||
name, value = line.split(":", 1)
|
||||
fields.append({"name": name.strip(), "value": value.strip() or "-", "inline": False})
|
||||
else:
|
||||
description += f"{line}\n"
|
||||
description = description.strip()
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
url=link or "https://github.com/jxxghp/MoviePilot",
|
||||
description=description if description else (text or None),
|
||||
color=0xE67E22
|
||||
)
|
||||
for field in fields:
|
||||
embed.add_field(name=field["name"], value=field["value"], inline=False)
|
||||
if image:
|
||||
embed.set_image(url=image)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def _build_media_embeds(medias: List[MediaInfo], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, media in enumerate(medias[:10], start=1):
|
||||
overview = media.get_overview_string(80)
|
||||
desc_parts = [
|
||||
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
|
||||
overview
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {media.title_year}",
|
||||
url=media.detail_link or discord.Embed.Empty,
|
||||
description="\n".join([p for p in desc_parts if p]),
|
||||
color=0x5865F2
|
||||
)
|
||||
if media.get_poster_image():
|
||||
embed.set_thumbnail(url=media.get_poster_image())
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, context in enumerate(torrents[:10], start=1):
|
||||
torrent = context.torrent_info
|
||||
meta = MetaInfo(torrent.title, torrent.description)
|
||||
title_text = f"{meta.season_episode} {meta.resource_term} {meta.video_term} {meta.release_group}"
|
||||
title_text = re.sub(r"\s+", " ", title_text).strip()
|
||||
detail = [
|
||||
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}↑",
|
||||
meta.resource_term,
|
||||
meta.video_term
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {title_text or torrent.title}",
|
||||
url=torrent.page_url or discord.Embed.Empty,
|
||||
description="\n".join([d for d in detail if d]),
|
||||
color=0x00A86B
|
||||
)
|
||||
poster = getattr(torrent, "poster", None)
|
||||
if poster:
|
||||
embed.set_thumbnail(url=poster)
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_default_buttons(count: int) -> List[List[dict]]:
|
||||
buttons: List[List[dict]] = []
|
||||
max_rows = 5
|
||||
max_per_row = 5
|
||||
capped = min(count, max_rows * max_per_row)
|
||||
for idx in range(1, capped + 1):
|
||||
row_idx = (idx - 1) // max_per_row
|
||||
if len(buttons) <= row_idx:
|
||||
buttons.append([])
|
||||
buttons[row_idx].append({"text": f"选择 {idx}", "callback_data": str(idx)})
|
||||
if count > capped:
|
||||
logger.warn(f"按钮数量超过 Discord 限制,仅展示前 {capped} 个")
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
|
||||
has_buttons = buttons and any(buttons)
|
||||
if not has_buttons and not link:
|
||||
return None
|
||||
|
||||
view = discord.ui.View(timeout=None)
|
||||
if buttons:
|
||||
for row_index, button_row in enumerate(buttons[:5]):
|
||||
for button in button_row[:5]:
|
||||
if "url" in button:
|
||||
btn = discord.ui.Button(label=button.get("text", "链接"),
|
||||
url=button["url"],
|
||||
style=discord.ButtonStyle.link)
|
||||
else:
|
||||
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
|
||||
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
|
||||
custom_id=custom_id,
|
||||
style=discord.ButtonStyle.primary)
|
||||
view.add_item(btn)
|
||||
elif link:
|
||||
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
# 优先使用明确的聊天 ID
|
||||
if chat_id:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
return channel
|
||||
try:
|
||||
return await self._client.fetch_channel(int(chat_id))
|
||||
except Exception as err:
|
||||
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 私聊
|
||||
if userid:
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
return dm
|
||||
|
||||
# 配置的广播频道
|
||||
if self._broadcast_channel:
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
channel = await self._client.fetch_channel(self._channel_id)
|
||||
except Exception as err:
|
||||
logger.warn(f"通过配置的频道ID获取 Discord 频道失败:{err}")
|
||||
channel = None
|
||||
self._broadcast_channel = channel
|
||||
if channel:
|
||||
return channel
|
||||
|
||||
# 按 Guild 寻找一个可用文本频道
|
||||
target_guilds = []
|
||||
if self._guild_id:
|
||||
guild = self._client.get_guild(self._guild_id)
|
||||
if guild:
|
||||
target_guilds.append(guild)
|
||||
else:
|
||||
target_guilds = list(self._client.guilds)
|
||||
|
||||
for guild in target_guilds:
|
||||
for channel in guild.text_channels:
|
||||
if guild.me and channel.permissions_for(guild.me).send_messages:
|
||||
self._broadcast_channel = channel
|
||||
return channel
|
||||
return None
|
||||
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
if userid in self._user_dm_cache:
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
user_obj = self._client.get_user(int(userid)) or await self._client.fetch_user(int(userid))
|
||||
if not user_obj:
|
||||
return None
|
||||
dm = user_obj.dm_channel or await user_obj.create_dm()
|
||||
if dm:
|
||||
self._user_dm_cache[userid] = dm
|
||||
return dm
|
||||
except Exception as err:
|
||||
logger.error(f"获取 Discord 私聊失败:{err}")
|
||||
return None
|
||||
|
||||
def _should_process_message(self, message: discord.Message) -> bool:
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return True
|
||||
content = message.content or ""
|
||||
# 仅处理 @Bot 或斜杠命令
|
||||
if self._client.user and self._client.user.mentioned_in(message):
|
||||
return True
|
||||
if content.startswith("/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _clean_bot_mention(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
if self._bot_user_id:
|
||||
mention_pattern = rf"<@!?{self._bot_user_id}>"
|
||||
content = re.sub(mention_pattern, "", content).strip()
|
||||
return content
|
||||
|
||||
async def _post_to_ds(self, payload: Dict[str, Any]) -> None:
|
||||
try:
|
||||
proxy = None
|
||||
if settings.PROXY:
|
||||
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
|
||||
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
|
||||
await client.post(self._ds_url, json=payload)
|
||||
except Exception as err:
|
||||
logger.error(f"转发 Discord 消息失败:{err}")
|
||||
@@ -2,11 +2,11 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.emby.emby import Emby
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
|
||||
@@ -18,20 +18,6 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
|
||||
super().init_service(service_name=Emby.__name__.lower(),
|
||||
service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Emby模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Emby"
|
||||
|
||||
@@ -640,7 +640,7 @@ class Emby:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
year=item.get("ProductionYear"),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
|
||||
@@ -37,6 +37,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
transtype = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
}
|
||||
|
||||
# 文件块大小,默认10MB
|
||||
@@ -635,7 +636,39 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
return False
|
||||
|
||||
def link(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
|
||||
pass
|
||||
"""
|
||||
硬链接文件
|
||||
Samba服务器需要开启 unix extensions 支持
|
||||
"""
|
||||
try:
|
||||
self._check_connection()
|
||||
src_path = self._normalize_path(fileitem.path)
|
||||
dst_path = self._normalize_path(target_file)
|
||||
|
||||
# 检查源文件是否存在
|
||||
if not smbclient.path.exists(src_path):
|
||||
raise FileNotFoundError(f"源文件不存在: {src_path}")
|
||||
|
||||
# 确保目标路径的父目录存在
|
||||
dst_parent = "\\".join(dst_path.rsplit("\\", 1)[:-1])
|
||||
if dst_parent and not smbclient.path.exists(dst_parent):
|
||||
logger.info(f"【SMB】创建目标目录: {dst_parent}")
|
||||
smbclient.makedirs(dst_parent, exist_ok=True)
|
||||
|
||||
# 尝试创建硬链接
|
||||
smbclient.link(src_path, dst_path)
|
||||
logger.info(f"【SMB】硬链接创建成功: {src_path} -> {dst_path}")
|
||||
return True
|
||||
|
||||
except SMBResponseException as e:
|
||||
# SMB协议错误,可能不支持硬链接
|
||||
logger.error(f"【SMB】创建硬链接失败(当前Samba服务器可能不支持硬链接): {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】创建硬链接失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def softlink(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
|
||||
pass
|
||||
|
||||
@@ -418,6 +418,9 @@ class TransHandler:
|
||||
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
|
||||
elif fileitem.storage == target_storage:
|
||||
# 同一网盘
|
||||
if not source_oper.is_support_transtype(transfer_type):
|
||||
return None, f"存储 {fileitem.storage} 不支持 {transfer_type} 整理方式"
|
||||
|
||||
if transfer_type == "copy":
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
@@ -438,6 +441,11 @@ class TransHandler:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 创建硬链接失败"
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
html_text = self._prepare_html_text(html_text)
|
||||
html = etree.HTML(html_text)
|
||||
try:
|
||||
tmps = html.xpath('//a[contains(@href, "user.php?id=")]')
|
||||
tmps = html.xpath('//a[contains(@href, "user.php?id=") or contains(@href, "user?id=")]')
|
||||
if tmps:
|
||||
user_id_match = re.search(r"user.php\?id=(\d+)", tmps[0].attrib['href'])
|
||||
user_id_match = re.search(r"user(?:\.php)?\?id=(\d+)", tmps[0].attrib['href'])
|
||||
if user_id_match and user_id_match.group().strip():
|
||||
self.userid = user_id_match.group(1)
|
||||
self._torrent_seeding_page = f"torrents.php?type=seeding&userid={self.userid}"
|
||||
@@ -42,13 +42,13 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
|
||||
self.ratio = 0.0 if self.download <= 0.0 else round(self.upload / self.download, 3)
|
||||
|
||||
tmps = html.xpath('//a[contains(@href, "bonus.php")]/@data-tooltip')
|
||||
tmps = html.xpath('//a[contains(@href, "bonus")]/@data-tooltip')
|
||||
if tmps:
|
||||
bonus_match = re.search(r"([\d,.]+)", tmps[0])
|
||||
if bonus_match and bonus_match.group(1).strip():
|
||||
self.bonus = StringUtils.str_float(bonus_match.group(1))
|
||||
else:
|
||||
tmps = html.xpath('//a[contains(@href, "bonus.php")]')
|
||||
tmps = html.xpath('//a[contains(@href, "bonus")]')
|
||||
if tmps:
|
||||
bonus_text = tmps[0].xpath("string(.)")
|
||||
bonus_match = re.search(r"([\d,.]+)", bonus_text)
|
||||
@@ -142,7 +142,7 @@ class GazelleSiteUserInfo(SiteParserBase):
|
||||
|
||||
# 是否存在下页数据
|
||||
next_page = None
|
||||
next_page_text = html.xpath('//a[contains(.//text(), "Next") or contains(.//text(), "下一页")]/@href')
|
||||
next_page_text = html.xpath('//a[contains(.//text(), "Next") or contains(.//text(), "下一页") or contains(@title, "下一页") or contains(@title, "Next")]/@href')
|
||||
if next_page_text:
|
||||
next_page = next_page_text[-1].strip()
|
||||
finally:
|
||||
|
||||
@@ -52,7 +52,7 @@ class TYemaSiteUserInfo(SiteParserBase):
|
||||
user_info = detail.get("data", {})
|
||||
self.userid = user_info.get("id")
|
||||
self.username = user_info.get("name")
|
||||
self.user_level = user_info.get("level")
|
||||
self.user_level = str(user_info.get("level")) if user_info.get("level") is not None else None
|
||||
self.join_at = StringUtils.unify_datetime_str(user_info.get("registerTime"))
|
||||
|
||||
self.upload = user_info.get('uploadSize')
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.jellyfin.jellyfin import Jellyfin
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
|
||||
@@ -19,20 +19,6 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
|
||||
super().init_service(service_name=Jellyfin.__name__.lower(),
|
||||
service_type=lambda conf: Jellyfin(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Jellyfin模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Jellyfin"
|
||||
|
||||
@@ -732,7 +732,7 @@ class Jellyfin:
|
||||
item_type=item.get("Type"),
|
||||
title=item.get("Name"),
|
||||
original_title=item.get("OriginalTitle"),
|
||||
year=str(item.get("ProductionYear")),
|
||||
year=item.get("ProductionYear"),
|
||||
tmdbid=int(tmdbid) if tmdbid else None,
|
||||
imdbid=item.get("ProviderIds", {}).get("Imdb"),
|
||||
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Optional, Tuple, Union, Any, List, Generator
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MediaServerBase
|
||||
from app.modules.plex.plex import Plex
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
|
||||
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
|
||||
|
||||
|
||||
class PlexModule(_ModuleBase, _MediaServerBase[Plex]):
|
||||
@@ -19,20 +19,6 @@ class PlexModule(_ModuleBase, _MediaServerBase[Plex]):
|
||||
super().init_service(service_name=Plex.__name__.lower(),
|
||||
service_type=lambda conf: Plex(**conf.config, sync_libraries=conf.sync_libraries))
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新初始化Plex模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Plex"
|
||||
|
||||
@@ -509,7 +509,7 @@ class Plex:
|
||||
item_type=item.type,
|
||||
title=item.title,
|
||||
original_title=item.originalTitle,
|
||||
year=str(item.year),
|
||||
year=item.year,
|
||||
tmdbid=ids.get("tmdb_id"),
|
||||
imdbid=ids.get("imdb_id"),
|
||||
tvdbid=ids.get("tvdb_id"),
|
||||
|
||||
@@ -7,13 +7,12 @@ from torrentool.torrent import Torrent
|
||||
from app import schemas
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _DownloaderBase
|
||||
from app.modules.qbittorrent.qbittorrent import Qbittorrent
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType, SystemConfigKey, EventType
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
@@ -26,20 +25,6 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
super().init_service(service_name=Qbittorrent.__name__.lower(),
|
||||
service_type=Qbittorrent)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Downloaders.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Qbittorrent模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Qbittorrent"
|
||||
@@ -139,12 +124,12 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
return None, None, None, "下载内容为空"
|
||||
|
||||
# 读取种子的名称
|
||||
torrent, content = __get_torrent_info()
|
||||
torrent_from_file, content = __get_torrent_info()
|
||||
# 检查是否为磁力链接
|
||||
is_magnet = isinstance(content, str) and content.startswith("magnet:") or isinstance(content,
|
||||
bytes) and content.startswith(
|
||||
b"magnet:")
|
||||
if not torrent and not is_magnet:
|
||||
if not torrent_from_file and not is_magnet:
|
||||
return None, None, None, f"添加种子任务失败:无法读取种子文件"
|
||||
|
||||
# 获取下载器
|
||||
@@ -165,7 +150,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
# 添加任务
|
||||
state = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=str(download_dir),
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
tag=tags,
|
||||
cookie=cookie,
|
||||
@@ -185,8 +170,8 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
try:
|
||||
for torrent in torrents:
|
||||
# 名称与大小相等则认为是同一个种子
|
||||
if torrent.get("name") == torrent.name \
|
||||
and torrent.get("total_size") == torrent.total_size:
|
||||
if torrent.get("name") == getattr(torrent_from_file, 'name', '') \
|
||||
and torrent.get("total_size") == getattr(torrent_from_file, 'total_size', 0):
|
||||
torrent_hash = torrent.get("hash")
|
||||
torrent_tags = [str(tag).strip() for tag in torrent.get("tags").split(',')]
|
||||
logger.warn(f"下载器中已存在该种子任务:{torrent_hash} - {torrent.get('name')}")
|
||||
|
||||
@@ -3,12 +3,11 @@ import re
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.slack.slack import Slack
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
@@ -21,20 +20,6 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
service_type=Slack)
|
||||
self._channel = MessageChannel.Slack
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Slack模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Slack"
|
||||
|
||||
@@ -5,11 +5,13 @@ from typing import Tuple, Union
|
||||
|
||||
from lxml import etree
|
||||
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase
|
||||
from app.schemas.file import FileURI
|
||||
from app.schemas.types import ModuleType, OtherModulesType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
@@ -87,15 +89,33 @@ class SubtitleModule(_ModuleBase):
|
||||
# 获取种子信息
|
||||
folder_name, _ = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
# 文件保存目录,如果是单文件种子,则folder_name是空,此时文件保存目录就是下载目录
|
||||
download_dir = download_dir / folder_name
|
||||
storageChain = StorageChain()
|
||||
# 等待目录存在
|
||||
working_dir_item = None
|
||||
# split download_dir into storage and path
|
||||
fileURI = FileURI.from_uri(download_dir.as_posix())
|
||||
storage = fileURI.storage
|
||||
download_dir = Path(fileURI.path)
|
||||
for _ in range(30):
|
||||
if download_dir.exists():
|
||||
found = storageChain.get_file_item(storage, download_dir / folder_name)
|
||||
if found:
|
||||
working_dir_item = found
|
||||
break
|
||||
time.sleep(1)
|
||||
# 目录仍然不存在,且有文件夹名,则创建目录
|
||||
if not download_dir.exists() and folder_name:
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not working_dir_item and folder_name:
|
||||
parent_dir_item = storageChain.get_file_item(storage, download_dir)
|
||||
if parent_dir_item:
|
||||
working_dir_item = storageChain.create_folder(
|
||||
parent_dir_item,
|
||||
folder_name
|
||||
)
|
||||
else:
|
||||
logger.error(f"下载根目录不存在,无法创建字幕文件夹:{download_dir}")
|
||||
return
|
||||
if not working_dir_item:
|
||||
logger.error(f"下载目录不存在,无法保存字幕:{download_dir / folder_name}")
|
||||
return
|
||||
# 读取网站代码
|
||||
request = RequestUtils(cookies=torrent.site_cookie, ua=torrent.site_ua)
|
||||
res = request.get_res(torrent.page_url)
|
||||
@@ -144,12 +164,12 @@ class SubtitleModule(_ModuleBase):
|
||||
shutil.unpack_archive(zip_file, zip_path, format='zip')
|
||||
# 遍历转移文件
|
||||
for sub_file in SystemUtils.list_files(zip_path, settings.RMT_SUBEXT):
|
||||
target_sub_file = download_dir / sub_file.name
|
||||
if target_sub_file.exists():
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
SystemUtils.copy(sub_file, target_sub_file)
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
# 删除临时文件
|
||||
try:
|
||||
shutil.rmtree(zip_path)
|
||||
@@ -160,9 +180,12 @@ class SubtitleModule(_ModuleBase):
|
||||
sub_file = settings.TEMP_PATH / file_name
|
||||
# 保存
|
||||
sub_file.write_bytes(ret.content)
|
||||
target_sub_file = download_dir / sub_file.name
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file}")
|
||||
SystemUtils.copy(sub_file, target_sub_file)
|
||||
target_sub_file = Path(working_dir_item.path) / Path(sub_file.name)
|
||||
if storageChain.get_file_item(storage, target_sub_file):
|
||||
logger.info(f"字幕文件已存在:{target_sub_file}")
|
||||
continue
|
||||
logger.info(f"转移字幕 {sub_file} 到 {target_sub_file} ...")
|
||||
storageChain.upload_file(working_dir_item, sub_file)
|
||||
else:
|
||||
logger.error(f"下载字幕文件失败:{sublink}")
|
||||
continue
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.synologychat.synologychat import SynologyChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
@@ -19,20 +18,6 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
service_type=SynologyChat)
|
||||
self._channel = MessageChannel.SynologyChat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载SynologyChat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Synology Chat"
|
||||
|
||||
@@ -41,7 +41,7 @@ class SynologyChat:
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
发送SynologyChat消息
|
||||
:param title: 消息标题
|
||||
:param text: 消息内容
|
||||
:param image: 消息图片地址
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
from typing import Dict, Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.telegram.telegram import Telegram
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, ConfigChangeEventData, \
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, \
|
||||
NotificationConf
|
||||
from app.schemas.types import ModuleType, ChainEventType, SystemConfigKey, EventType
|
||||
from app.schemas.types import ModuleType, ChainEventType
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
|
||||
@@ -26,20 +24,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
service_type=Telegram)
|
||||
self._channel = MessageChannel.Telegram
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Telegram模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Telegram"
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from threading import Event
|
||||
from typing import Optional, List, Dict, Callable
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import telebot
|
||||
from telebot import TeleBot, apihelper
|
||||
from telebot.types import BotCommand, InlineKeyboardMarkup, InlineKeyboardButton, InputMediaPhoto
|
||||
from telegramify_markdown import standardize, telegramify
|
||||
from telegramify_markdown.type import ContentTypes, SentType
|
||||
from telebot import apihelper
|
||||
from telebot.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||||
from telebot.types import InputMediaPhoto
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.utils.common import retry
|
||||
from app.utils.http import RequestUtils
|
||||
@@ -28,8 +26,7 @@ class RetryException(Exception):
|
||||
|
||||
class Telegram:
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_event = Event()
|
||||
_bot: telebot.TeleBot = None
|
||||
_bot: TeleBot = None
|
||||
_callback_handlers: Dict[str, Callable] = {} # 存储回调处理器
|
||||
_user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
@@ -54,7 +51,7 @@ class Telegram:
|
||||
else:
|
||||
apihelper.proxy = settings.PROXY
|
||||
# bot
|
||||
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
_bot = TeleBot(self._telegram_token, parse_mode="MarkdownV2")
|
||||
# 记录句柄
|
||||
self._bot = _bot
|
||||
# 获取并存储bot用户名用于@检测
|
||||
@@ -536,10 +533,10 @@ class Telegram:
|
||||
'reply_markup': reply_markup
|
||||
}
|
||||
|
||||
try:
|
||||
# 处理图片
|
||||
image = self.__process_image(image) if image else None
|
||||
# 处理图片
|
||||
image = self.__process_image(image)
|
||||
|
||||
try:
|
||||
# 图片消息的标题长度限制为1024,文本消息为4096
|
||||
caption_limit = 1024 if image else 4096
|
||||
if len(caption) < caption_limit:
|
||||
@@ -553,23 +550,17 @@ class Telegram:
|
||||
logger.error(f"发送Telegram消息失败: {e}")
|
||||
return False
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __process_image(self, image_url: str) -> bytes:
|
||||
@staticmethod
|
||||
def __process_image(image_url: Optional[str]) -> Optional[bytes]:
|
||||
"""
|
||||
处理图片URL,获取图片内容
|
||||
"""
|
||||
try:
|
||||
res = RequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
ua=settings.NORMAL_USER_AGENT
|
||||
).get_res(image_url)
|
||||
|
||||
if not res or not res.content:
|
||||
raise RetryException("获取图片失败")
|
||||
|
||||
return res.content
|
||||
except Exception as e:
|
||||
raise
|
||||
if not image_url:
|
||||
return None
|
||||
image = ImageHelper().fetch_image(image_url)
|
||||
if not image:
|
||||
logger.warn(f"图片获取失败: {image_url},仅发送文本消息")
|
||||
return image
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs):
|
||||
@@ -588,7 +579,7 @@ class Telegram:
|
||||
text=standardize(caption),
|
||||
**kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise RetryException(f"发送{'图片' if image else '文本'}消息失败")
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
@@ -637,7 +628,7 @@ class Telegram:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise RetryException("发送长消息失败") from e
|
||||
raise
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
"""
|
||||
@@ -650,7 +641,7 @@ class Telegram:
|
||||
self._bot.delete_my_commands()
|
||||
self._bot.set_my_commands(
|
||||
commands=[
|
||||
telebot.types.BotCommand(cmd[1:], str(desc.get("description"))) for cmd, desc in
|
||||
BotCommand(cmd[1:], str(desc.get("description"))) for cmd, desc in
|
||||
commands.items()
|
||||
]
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@ from app.modules.themoviedb.category import CategoryHelper
|
||||
from app.modules.themoviedb.scraper import TmdbScraper
|
||||
from app.modules.themoviedb.tmdb_cache import TmdbCache
|
||||
from app.modules.themoviedb.tmdbapi import TmdbApi
|
||||
from app.schemas import MediaPerson
|
||||
from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
@@ -23,6 +22,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
"""
|
||||
TMDB媒体信息匹配
|
||||
"""
|
||||
CONFIG_WATCH = {"PROXY_HOST", "TMDB_API_DOMAIN", "TMDB_API_KEY", "TMDB_LOCALE"}
|
||||
|
||||
# 元数据缓存
|
||||
cache: TmdbCache = None
|
||||
@@ -39,6 +39,12 @@ class TheMovieDbModule(_ModuleBase):
|
||||
self.category = CategoryHelper()
|
||||
self.scraper = TmdbScraper()
|
||||
|
||||
def on_config_changed(self):
|
||||
# 停止模块
|
||||
self.stop()
|
||||
# 初始化模块
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TheMovieDb"
|
||||
@@ -635,7 +641,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return medias
|
||||
return []
|
||||
|
||||
def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
def search_persons(self, name: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
搜索人物信息
|
||||
"""
|
||||
@@ -645,10 +651,10 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return []
|
||||
results = self.tmdb.search_persons(name)
|
||||
if results:
|
||||
return [MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return [schemas.MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return []
|
||||
|
||||
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
|
||||
async def async_search_persons(self, name: str) -> Optional[List[schemas.MediaPerson]]:
|
||||
"""
|
||||
异步搜索人物信息
|
||||
"""
|
||||
@@ -658,7 +664,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
return []
|
||||
results = await self.tmdb.async_search_persons(name)
|
||||
if results:
|
||||
return [MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return [schemas.MediaPerson(source='themoviedb', **person) for person in results]
|
||||
return []
|
||||
|
||||
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
|
||||
|
||||
@@ -643,17 +643,23 @@ class TmdbApi:
|
||||
reverse=True
|
||||
)
|
||||
for tv in tvs:
|
||||
# 年份
|
||||
# 使用年份、名称匹配
|
||||
tv_year = tv.get('first_air_date')[0:4] if tv.get('first_air_date') else None
|
||||
if (self.__compare_names(name, tv.get('name'))
|
||||
or self.__compare_names(name, tv.get('original_name'))) \
|
||||
and (tv_year == str(season_year)):
|
||||
return tv
|
||||
# 匹配别名、译名
|
||||
# 获取别名、译名重新匹配
|
||||
if not tv.get("names"):
|
||||
tv = self.get_info(mtype=MediaType.TV, tmdbid=tv.get("id"))
|
||||
if not tv or not self.__compare_names(name, tv.get("names")):
|
||||
if not tv or not (
|
||||
self.__compare_names(name, tv.get("name"))
|
||||
or self.__compare_names(name, tv.get("original_name"))
|
||||
or self.__compare_names(name, tv.get("names"))):
|
||||
continue
|
||||
if tv_year == str(season_year):
|
||||
return tv
|
||||
# 季年份匹配
|
||||
if __season_match(tv_info=tv, _season_year=season_year):
|
||||
return tv
|
||||
return {}
|
||||
@@ -744,11 +750,11 @@ class TmdbApi:
|
||||
if validation_result is not None:
|
||||
return validation_result
|
||||
|
||||
logger.info("正在从TheDbMovie网站查询:%s ..." % name)
|
||||
logger.info("正在从TheMovieDb网站查询:%s ..." % name)
|
||||
tmdb_url = self._build_tmdb_search_url(name)
|
||||
res = RequestUtils(timeout=5, ua=settings.NORMAL_USER_AGENT, proxies=settings.PROXY).get_res(url=tmdb_url)
|
||||
if res is None:
|
||||
logger.error("无法连接TheDbMovie")
|
||||
logger.error("无法连接TheMovieDb")
|
||||
return None
|
||||
|
||||
# 响应验证
|
||||
|
||||
@@ -7,13 +7,12 @@ from transmission_rpc import File
|
||||
from app import schemas
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _DownloaderBase
|
||||
from app.modules.transmission.transmission import Transmission
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType, SystemConfigKey, EventType
|
||||
from app.schemas.types import TorrentStatus, ModuleType, DownloaderType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
@@ -26,20 +25,6 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
super().init_service(service_name=Transmission.__name__.lower(),
|
||||
service_type=Transmission)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Downloaders.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Transmission模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Transmission"
|
||||
@@ -140,12 +125,12 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
return None, None, None, "下载内容为空"
|
||||
|
||||
# 读取种子的名称
|
||||
torrent, content = __get_torrent_info()
|
||||
torrent_from_file, content = __get_torrent_info()
|
||||
# 检查是否为磁力链接
|
||||
is_magnet = isinstance(content, str) and content.startswith("magnet:") or isinstance(content,
|
||||
bytes) and content.startswith(
|
||||
b"magnet:")
|
||||
if not torrent and not is_magnet:
|
||||
if not torrent_from_file and not is_magnet:
|
||||
return None, None, None, f"添加种子任务失败:无法读取种子文件"
|
||||
|
||||
# 获取下载器
|
||||
@@ -164,9 +149,9 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
else:
|
||||
labels = None
|
||||
# 添加任务
|
||||
torrent = server.add_torrent(
|
||||
added_torrent = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=str(download_dir),
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
labels=labels,
|
||||
cookie=cookie
|
||||
@@ -174,7 +159,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
# TR 始终使用原始种子布局, 返回"Original"
|
||||
torrent_layout = "Original"
|
||||
|
||||
if not torrent:
|
||||
if not added_torrent:
|
||||
# 查询所有下载器的种子
|
||||
torrents, error = server.get_torrents()
|
||||
if error:
|
||||
@@ -183,7 +168,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
try:
|
||||
for torrent in torrents:
|
||||
# 名称与大小相等则认为是同一个种子
|
||||
if torrent.name == torrent.name and torrent.total_size == torrent.total_size:
|
||||
if torrent.name == getattr(torrent_from_file, 'name', '') and torrent.total_size == getattr(torrent_from_file, 'total_size', 0):
|
||||
torrent_hash = torrent.hashString
|
||||
logger.warn(f"下载器中已存在该种子任务:{torrent_hash} - {torrent.name}")
|
||||
# 给种子打上标签
|
||||
@@ -204,7 +189,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
|
||||
del torrents
|
||||
return None, None, None, f"添加种子任务失败:{content}"
|
||||
else:
|
||||
torrent_hash = torrent.hashString
|
||||
torrent_hash = added_torrent.hashString
|
||||
if is_paused:
|
||||
# 选择文件
|
||||
torrent_files = server.get_files(torrent_hash)
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.trimemedia.trimemedia import TrimeMedia
|
||||
from app.schemas import AuthCredentials, AuthInterceptCredentials
|
||||
from app.schemas.types import ChainEventType, MediaServerType, MediaType, ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas.types import ChainEventType, MediaServerType, MediaType, ModuleType
|
||||
|
||||
|
||||
class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
@@ -23,20 +23,6 @@ class TrimeMediaModule(_ModuleBase, _MediaServerBase[TrimeMedia]):
|
||||
),
|
||||
)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.MediaServers.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载飞牛影视模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "飞牛影视"
|
||||
|
||||
@@ -449,7 +449,7 @@ class TrimeMedia:
|
||||
item_type=item_type,
|
||||
title=item.title,
|
||||
original_title=item.original_title,
|
||||
year=str(year),
|
||||
year=year,
|
||||
tmdbid=item.tmdb_id,
|
||||
imdbid=item.imdb_id,
|
||||
user_state=user_state,
|
||||
|
||||
@@ -2,12 +2,11 @@ import json
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.vocechat.vocechat import VoceChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
@@ -20,20 +19,6 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
service_type=VoceChat)
|
||||
self._channel = MessageChannel.VoceChat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载VoceChat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "VoceChat"
|
||||
|
||||
@@ -4,11 +4,10 @@ from typing import Union, Tuple
|
||||
from pywebpush import webpush, WebPushException
|
||||
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import Notification, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, MessageChannel, SystemConfigKey, EventType
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import ModuleType, MessageChannel
|
||||
|
||||
|
||||
class WebPushModule(_ModuleBase, _MessageBase):
|
||||
@@ -20,20 +19,6 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
super().init_service(service_name=self.get_name().lower())
|
||||
self._channel = MessageChannel.WebPush
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载WebPush模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "WebPush"
|
||||
|
||||
@@ -3,13 +3,13 @@ import xml.dom.minidom
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from app.modules.wechat.wechat import WeChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, ConfigChangeEventData
|
||||
from app.schemas.types import ModuleType, ChainEventType, SystemConfigKey, EventType
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData
|
||||
from app.schemas.types import ModuleType, ChainEventType
|
||||
from app.utils.dom import DomUtils
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
@@ -24,20 +24,6 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
service_type=WeChat)
|
||||
self._channel = MessageChannel.Wechat
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Notifications.value]:
|
||||
return
|
||||
logger.info("配置变更,重新加载Wechat模块...")
|
||||
self.init_module()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "微信"
|
||||
|
||||
@@ -17,13 +17,12 @@ from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.cache import TTLCache, FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas import FileItem
|
||||
from app.schemas.types import SystemConfigKey, EventType
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import SingletonClass
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -60,10 +59,11 @@ class FileMonitorHandler(FileSystemEventHandler):
|
||||
logger.error(f"on_moved 异常: {e}")
|
||||
|
||||
|
||||
class Monitor(metaclass=SingletonClass):
|
||||
class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"""
|
||||
目录监控处理链,单例模式
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Directories.value}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -84,20 +84,12 @@ class Monitor(metaclass=SingletonClass):
|
||||
# 启动目录监控和文件整理
|
||||
self.init()
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in [SystemConfigKey.Directories.value]:
|
||||
return
|
||||
logger.info("配置变更事件触发,重新初始化目录监控...")
|
||||
def on_config_changed(self):
|
||||
self.init()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "目录监控"
|
||||
|
||||
def save_snapshot(self, storage: str, snapshot: Dict, file_count: int = 0,
|
||||
last_snapshot_time: Optional[float] = None):
|
||||
"""
|
||||
|
||||
@@ -22,16 +22,17 @@ from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.helper.image import WallpaperHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType, Workflow, ConfigChangeEventData
|
||||
from app.schemas import Notification, NotificationType, Workflow
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.gc import get_memory_usage
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import SingletonClass
|
||||
from app.utils.timer import TimerUtils
|
||||
|
||||
@@ -42,10 +43,20 @@ class SchedulerChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class Scheduler(metaclass=SingletonClass):
|
||||
class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"""
|
||||
定时任务管理
|
||||
"""
|
||||
CONFIG_WATCH = {
|
||||
"DEV",
|
||||
"COOKIECLOUD_INTERVAL",
|
||||
"MEDIASERVER_SYNC_INTERVAL",
|
||||
"SUBSCRIBE_SEARCH",
|
||||
"SUBSCRIBE_SEARCH_INTERVAL",
|
||||
"SUBSCRIBE_MODE",
|
||||
"SUBSCRIBE_RSS_INTERVAL",
|
||||
"SITEDATA_REFRESH_INTERVAL",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
# 定时服务
|
||||
@@ -63,22 +74,12 @@ class Scheduler(metaclass=SingletonClass):
|
||||
# 初始化
|
||||
self.init()
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEV', 'COOKIECLOUD_INTERVAL', 'MEDIASERVER_SYNC_INTERVAL', 'SUBSCRIBE_SEARCH',
|
||||
'SUBSCRIBE_SEARCH_INTERVAL', 'SUBSCRIBE_MODE', 'SUBSCRIBE_RSS_INTERVAL',
|
||||
'SITEDATA_REFRESH_INTERVAL']:
|
||||
return
|
||||
logger.info(f"配置项 {event_data.key} 变更,重新初始化定时服务...")
|
||||
def on_config_changed(self):
|
||||
self.init()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "定时服务"
|
||||
|
||||
def init(self):
|
||||
"""
|
||||
初始化定时服务
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Set, Callable
|
||||
from typing import Iterable, Optional, Dict, Any, List, Set, Callable
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from app.schemas.message import MessageChannel
|
||||
from app.schemas.file import FileItem
|
||||
@@ -27,10 +27,29 @@ class ConfigChangeEventData(BaseEventData):
|
||||
"""
|
||||
ConfigChange 事件的数据模型
|
||||
"""
|
||||
key: str = Field(..., description="配置项的键")
|
||||
key: set[str] = Field(..., description="配置项的键(集合类型)")
|
||||
value: Optional[Any] = Field(default=None, description="配置项的新值")
|
||||
change_type: str = Field(default="update", description="配置项的变更类型,如 'add', 'update', 'delete'")
|
||||
|
||||
@field_validator('key', mode='before')
|
||||
@classmethod
|
||||
def convert_to_set(cls, v):
|
||||
"""将输入的 str、list、dict.keys() 等转为 set"""
|
||||
if v is None:
|
||||
return set()
|
||||
elif isinstance(v, str):
|
||||
return {v}
|
||||
elif isinstance(v, dict):
|
||||
return set(str(k) for k in v.keys())
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return set(str(item) for item in v)
|
||||
elif isinstance(v, set):
|
||||
return set(str(item) for item in v)
|
||||
elif isinstance(v, Iterable):
|
||||
return set(str(item) for item in v)
|
||||
else:
|
||||
return {str(v)}
|
||||
|
||||
|
||||
class ChainEventData(BaseEventData):
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
from app.schemas.types import StorageSchema
|
||||
|
||||
|
||||
class FileItem(BaseModel):
|
||||
# 存储类型
|
||||
storage: Optional[str] = Field(default="local")
|
||||
# 类型 dir/file
|
||||
type: Optional[str] = None
|
||||
class FileURI(BaseModel):
|
||||
# 文件路径
|
||||
path: Optional[str] = "/"
|
||||
# 存储类型
|
||||
storage: Optional[str] = Field(default="local")
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self.path if self.storage == "local" else f"{self.storage}:{self.path}"
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> "FileURI":
|
||||
storage, path = 'local', uri
|
||||
for s in StorageSchema:
|
||||
protocol = f"{s.value}:"
|
||||
if uri.startswith(protocol):
|
||||
path = uri[len(protocol):]
|
||||
storage = s.value
|
||||
break
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = Path(path).as_posix()
|
||||
return cls(storage=storage, path=path)
|
||||
|
||||
class FileItem(FileURI):
|
||||
# 类型 dir/file
|
||||
type: Optional[str] = None
|
||||
# 文件名
|
||||
name: Optional[str] = None
|
||||
# 文件名
|
||||
@@ -46,3 +68,4 @@ class StorageUsage(BaseModel):
|
||||
class StorageTransType(BaseModel):
|
||||
# 传输类型
|
||||
transtype: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class RefreshMediaItem(BaseModel):
|
||||
# 标题
|
||||
title: Optional[str] = None
|
||||
# 年份
|
||||
year: Optional[str] = None
|
||||
year: Optional[Union[str, int]] = None
|
||||
# 类型
|
||||
type: Optional[MediaType] = None
|
||||
# 类别
|
||||
@@ -110,7 +110,7 @@ class MediaServerItem(BaseModel):
|
||||
# 原标题
|
||||
original_title: Optional[str] = None
|
||||
# 年份
|
||||
year: Optional[str] = None
|
||||
year: Optional[Union[str, int]] = None
|
||||
# TMDBID
|
||||
tmdbid: Optional[int] = None
|
||||
# IMDBID
|
||||
|
||||
@@ -70,7 +70,7 @@ class Notification(BaseModel):
|
||||
# 用户ID
|
||||
userid: Optional[Union[str, int]] = None
|
||||
# 用户名称
|
||||
username: Optional[str] = None
|
||||
username: Optional[Union[str, int]] = None
|
||||
# 时间
|
||||
date: Optional[str] = None
|
||||
# 消息方向
|
||||
@@ -221,6 +221,22 @@ class ChannelCapabilityManager:
|
||||
max_button_text_length=25,
|
||||
fallback_enabled=True
|
||||
),
|
||||
MessageChannel.Discord: ChannelCapabilities(
|
||||
channel=MessageChannel.Discord,
|
||||
capabilities={
|
||||
ChannelCapability.INLINE_BUTTONS,
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MESSAGE_DELETION,
|
||||
ChannelCapability.CALLBACK_QUERIES,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS
|
||||
},
|
||||
max_buttons_per_row=5,
|
||||
max_button_rows=5,
|
||||
max_button_text_length=80,
|
||||
fallback_enabled=True
|
||||
),
|
||||
MessageChannel.SynologyChat: ChannelCapabilities(
|
||||
channel=MessageChannel.SynologyChat,
|
||||
capabilities={
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
class RadarrMovie(BaseModel):
|
||||
id: Optional[int] = None
|
||||
title: Optional[str] = None
|
||||
year: Optional[str] = None
|
||||
year: Optional[str | int] = None
|
||||
isAvailable: bool = False
|
||||
monitored: bool = False
|
||||
tmdbId: Optional[int] = None
|
||||
@@ -31,7 +31,7 @@ class SonarrSeries(BaseModel):
|
||||
images: list = Field(default_factory=list)
|
||||
remotePoster: Optional[str] = None
|
||||
seasons: list = Field(default_factory=list)
|
||||
year: Optional[str] = None
|
||||
year: Optional[str | int] = None
|
||||
path: Optional[str] = None
|
||||
profileId: Optional[int] = None
|
||||
languageProfileId: Optional[int] = None
|
||||
|
||||
@@ -51,6 +51,8 @@ class DownloaderConf(BaseModel):
|
||||
config: Optional[dict] = Field(default_factory=dict)
|
||||
# 是否启用
|
||||
enabled: Optional[bool] = False
|
||||
# 路径映射
|
||||
path_mapping: Optional[list[tuple[str, str]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class NotificationConf(BaseModel):
|
||||
|
||||
@@ -21,7 +21,7 @@ class Token(BaseModel):
|
||||
# 详细权限
|
||||
permissions: Optional[dict] = Field(default_factory=dict)
|
||||
# 是否显示配置向导
|
||||
widzard: Optional[bool] = None
|
||||
wizard: Optional[bool] = None
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
|
||||
@@ -265,6 +265,7 @@ class MessageChannel(Enum):
|
||||
Wechat = "微信"
|
||||
Telegram = "Telegram"
|
||||
Slack = "Slack"
|
||||
Discord = "Discord"
|
||||
SynologyChat = "SynologyChat"
|
||||
VoceChat = "VoceChat"
|
||||
Web = "Web"
|
||||
|
||||
66
app/utils/mixins.py
Normal file
66
app/utils/mixins.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import inspect
|
||||
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class ConfigReloadMixin:
|
||||
"""配置重载混入类
|
||||
|
||||
继承此 Mixin 类的类,会在配置变更时自动调用 on_config_changed 方法。
|
||||
在类中定义 CONFIG_WATCH 集合,指定需要监听的配置项
|
||||
重写 on_config_changed 方法实现具体的重载逻辑
|
||||
可选地重写 get_reload_name 方法提供模块名称(用于日志显示)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
config_watch = getattr(cls, 'CONFIG_WATCH', None)
|
||||
if not config_watch:
|
||||
return
|
||||
|
||||
# 检查 on_config_changed 方法是否为异步
|
||||
is_async = inspect.iscoroutinefunction(cls.on_config_changed)
|
||||
|
||||
method_name = 'handle_config_changed'
|
||||
|
||||
# 创建事件处理函数
|
||||
def create_handler(is_async):
|
||||
if is_async:
|
||||
async def wrapper(self: ConfigReloadMixin, event: Event):
|
||||
if not event:
|
||||
return
|
||||
changed_keys = getattr(event.event_data, "key", set()) & config_watch
|
||||
if not changed_keys:
|
||||
return
|
||||
logger.info(f"配置 {', '.join(changed_keys)} 变更,重载 {self.get_reload_name()}...")
|
||||
await self.on_config_changed()
|
||||
else:
|
||||
def wrapper(self: ConfigReloadMixin, event: Event):
|
||||
if not event:
|
||||
return
|
||||
changed_keys = getattr(event.event_data, "key", set()) & config_watch
|
||||
if not changed_keys:
|
||||
return
|
||||
logger.info(f"配置 {', '.join(changed_keys)} 变更,重载 {self.get_reload_name()}...")
|
||||
self.on_config_changed()
|
||||
|
||||
return wrapper
|
||||
|
||||
# 创建并设置处理函数
|
||||
handler = create_handler(is_async)
|
||||
handler.__module__ = cls.__module__
|
||||
handler.__qualname__ = f'{cls.__name__}.{method_name}'
|
||||
setattr(cls, method_name, handler)
|
||||
# 添加为事件处理器
|
||||
eventmanager.add_event_listener(EventType.ConfigChanged, handler)
|
||||
|
||||
def on_config_changed(self):
|
||||
"""子类重写此方法实现具体重载逻辑"""
|
||||
pass
|
||||
|
||||
def get_reload_name(self):
|
||||
"""功能/模块名称"""
|
||||
return self.__class__.__name__
|
||||
@@ -16,7 +16,7 @@ class AddDownloadParams(ActionParams):
|
||||
添加下载资源参数
|
||||
"""
|
||||
downloader: Optional[str] = Field(default=None, description="下载器")
|
||||
save_path: Optional[str] = Field(default=None, description="保存路径")
|
||||
save_path: Optional[str] = Field(default=None, description="保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等")
|
||||
labels: Optional[str] = Field(default=None, description="标签(,分隔)")
|
||||
only_lack: Optional[bool] = Field(default=False, description="仅下载缺失的资源")
|
||||
|
||||
|
||||
@@ -94,6 +94,7 @@ COPY --from=prepare_venv --chmod=777 ${VENV_PATH} ${VENV_PATH}
|
||||
|
||||
# playwright 环境
|
||||
RUN playwright install-deps chromium \
|
||||
&& playwright install-deps firefox \
|
||||
&& apt-get autoremove -y \
|
||||
&& apt-get clean \
|
||||
&& rm -rf \
|
||||
|
||||
@@ -231,9 +231,9 @@ chown moviepilot:moviepilot /etc/hosts /tmp
|
||||
|
||||
# 下载浏览器内核
|
||||
if [[ "$HTTPS_PROXY" =~ ^https?:// ]] || [[ "$HTTPS_PROXY" =~ ^https?:// ]] || [[ "$PROXY_HOST" =~ ^https?:// ]]; then
|
||||
HTTPS_PROXY="${HTTPS_PROXY:-${https_proxy:-$PROXY_HOST}}" gosu moviepilot:moviepilot playwright install chromium
|
||||
HTTPS_PROXY="${HTTPS_PROXY:-${https_proxy:-$PROXY_HOST}}" gosu moviepilot:moviepilot playwright install ${PLAYWRIGHT_BROWSER_TYPE:-chromium}
|
||||
else
|
||||
gosu moviepilot:moviepilot playwright install chromium
|
||||
gosu moviepilot:moviepilot playwright install ${PLAYWRIGHT_BROWSER_TYPE:-chromium}
|
||||
fi
|
||||
|
||||
# 证书管理
|
||||
|
||||
213
docs/mcp-api.md
213
docs/mcp-api.md
@@ -1,9 +1,77 @@
|
||||
# MoviePilot 工具API文档
|
||||
# MoviePilot MCP (Model Context Protocol) API 文档
|
||||
|
||||
MoviePilot的智能体工具已通过HTTP API暴露,可以通过RESTful API调用所有工具。
|
||||
MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智能体(如 Claude, GPT 等)直接调用 MoviePilot 的功能进行媒体管理、搜索、订阅和下载。
|
||||
|
||||
## API端点
|
||||
## 1. 基础信息
|
||||
|
||||
* **基础路径**: `/api/v1/mcp`
|
||||
* **协议版本**: `2025-11-25, 2025-06-18, 2024-11-05`
|
||||
* **传输协议**: HTTP (JSON-RPC 2.0)
|
||||
* **认证方式**:
|
||||
* Header: `X-API-KEY: <你的API_KEY>`
|
||||
* Query: `?apikey=<你的API_KEY>`
|
||||
|
||||
## 2. 标准 MCP 协议 (JSON-RPC 2.0)
|
||||
|
||||
### 端点
|
||||
**POST** `/api/v1/mcp`
|
||||
|
||||
### 支持的方法
|
||||
- `initialize`: 初始化会话,协商协议版本和能力。
|
||||
- `notifications/initialized`: 客户端确认初始化完成。
|
||||
- `tools/list`: 获取可用工具列表。
|
||||
- `tools/call`: 调用特定工具。
|
||||
- `ping`: 连接存活检测。
|
||||
|
||||
---
|
||||
|
||||
## 4. 客户端配置示例
|
||||
|
||||
### Claude Desktop (Anthropic)
|
||||
|
||||
在Claude Desktop的配置文件中添加MoviePilot的MCP服务器配置:
|
||||
|
||||
**macOS**: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
**Windows**: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
|
||||
使用请求头方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp",
|
||||
"headers": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
或使用查询参数方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 错误码说明
|
||||
|
||||
| 错误码 | 消息 | 说明 |
|
||||
| :--- | :--- | :--- |
|
||||
| -32700 | Parse error | JSON 格式错误 |
|
||||
| -32600 | Invalid Request | 无效的 JSON-RPC 请求 |
|
||||
| -32601 | Method not found | 方法不存在 |
|
||||
| -32602 | Invalid params | 参数验证失败 |
|
||||
| -32002 | Session not found | 会话不存在或已过期 |
|
||||
| -32003 | Not initialized | 会话未完成初始化流程 |
|
||||
| -32603 | Internal error | 服务器内部错误 |
|
||||
|
||||
## 6. RESTful API
|
||||
所有工具相关的API端点都在 `/api/v1/mcp` 路径下(保持向后兼容)。
|
||||
|
||||
### 1. 列出所有工具
|
||||
@@ -137,142 +205,3 @@ MoviePilot的智能体工具已通过HTTP API暴露,可以通过RESTful API调
|
||||
"required": ["title", "year", "media_type"]
|
||||
}
|
||||
```
|
||||
|
||||
## MCP客户端配置
|
||||
|
||||
MoviePilot的MCP工具可以通过HTTP协议在支持MCP的客户端中使用。以下是常见MCP客户端的配置方法:
|
||||
|
||||
### Claude Desktop (Anthropic)
|
||||
|
||||
在Claude Desktop的配置文件中添加MoviePilot的MCP服务器配置:
|
||||
|
||||
**macOS**: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
**Windows**: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-http",
|
||||
"http://localhost:3001/api/v1/mcp"
|
||||
],
|
||||
"env": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**注意**: 如果MCP HTTP服务器不支持环境变量传递API Key,可以使用查询参数方式:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-http",
|
||||
"http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 其他支持MCP的聊天客户端
|
||||
|
||||
对于其他支持MCP协议的聊天客户端(如其他AI聊天助手、对话机器人等),通常可以通过配置文件或设置界面添加HTTP协议的MCP服务器。配置格式可能因客户端而异,但通常需要以下信息:
|
||||
|
||||
**配置参数**:
|
||||
1. **服务器类型**: HTTP
|
||||
2. **服务器地址**: `http://your-moviepilot-host:3001/api/v1/mcp`
|
||||
3. **认证方式**:
|
||||
- 在HTTP请求头中添加 `X-API-KEY: <your_api_key>`
|
||||
- 或在URL查询参数中添加 `apikey=<your_api_key>`
|
||||
|
||||
**示例配置**(通用格式):
|
||||
|
||||
使用请求头方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp",
|
||||
"headers": {
|
||||
"X-API-KEY": "your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
或使用查询参数方式:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"moviepilot": {
|
||||
"url": "http://localhost:3001/api/v1/mcp?apikey=your_api_key_here"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**支持的端点**:
|
||||
- `GET /tools` - 列出所有工具
|
||||
- `POST /tools/call` - 调用工具
|
||||
- `GET /tools/{tool_name}` - 获取工具详情
|
||||
- `GET /tools/{tool_name}/schema` - 获取工具参数Schema
|
||||
|
||||
配置完成后,您就可以在聊天对话中使用MoviePilot的各种工具,例如:
|
||||
- 添加媒体订阅
|
||||
- 查询下载历史
|
||||
- 搜索媒体资源
|
||||
- 管理媒体服务器
|
||||
- 等等...
|
||||
|
||||
### 获取API Key
|
||||
|
||||
API Key可以在MoviePilot的系统设置中生成和查看。请妥善保管您的API Key,不要泄露给他人。
|
||||
|
||||
## 认证
|
||||
|
||||
所有MCP API端点都需要认证。**仅支持API Key认证方式**:
|
||||
|
||||
- **请求头方式**: 在请求头中添加 `X-API-KEY: <api_key>`
|
||||
- **查询参数方式**: 在URL查询参数中添加 `apikey=<api_key>`
|
||||
|
||||
**获取API Key**: 在MoviePilot系统设置中生成和查看API Key。请妥善保管您的API Key,不要泄露给他人。
|
||||
|
||||
## 错误处理
|
||||
|
||||
API会返回标准的HTTP状态码:
|
||||
|
||||
- `200 OK`: 请求成功
|
||||
- `400 Bad Request`: 请求参数错误
|
||||
- `401 Unauthorized`: 未认证或API Key无效
|
||||
- `404 Not Found`: 工具不存在
|
||||
- `500 Internal Server Error`: 服务器内部错误
|
||||
|
||||
错误响应格式:
|
||||
```json
|
||||
{
|
||||
"detail": "错误描述信息"
|
||||
}
|
||||
```
|
||||
|
||||
## 架构说明
|
||||
|
||||
工具API通过FastAPI端点暴露,使用HTTP协议与客户端通信。所有工具共享相同的实现,确保功能一致性。
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **用户上下文**: API调用会使用当前认证用户的ID作为工具执行的用户上下文
|
||||
2. **会话隔离**: 每个API请求使用独立的会话ID
|
||||
3. **参数验证**: 工具参数会根据JSON Schema进行验证
|
||||
4. **错误日志**: 所有工具调用错误都会记录到MoviePilot日志系统
|
||||
|
||||
|
||||
6
package-lock.json
generated
6
package-lock.json
generated
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"name": "MoviePilot",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ cf_clearance~=0.31.0
|
||||
torrentool~=1.2.0
|
||||
slack-bolt~=1.23.0
|
||||
slack-sdk~=3.35.0
|
||||
discord.py==2.6.4
|
||||
chardet~=5.2.0
|
||||
starlette~=0.46.2
|
||||
PyVirtualDisplay~=3.0
|
||||
@@ -61,6 +62,7 @@ cachetools~=6.1.0
|
||||
fast-bencode~=1.1.7
|
||||
pystray~=0.19.5
|
||||
pyotp~=2.9.0
|
||||
webauthn~=2.7.0
|
||||
Pinyin2Hanzi~=0.1.1
|
||||
pywebpush~=2.0.3
|
||||
aiopathlib~=0.6.0
|
||||
@@ -87,4 +89,5 @@ langchain-openai~=0.3.33
|
||||
langchain-google-genai~=2.0.10
|
||||
langchain-deepseek~=0.1.4
|
||||
langchain-experimental~=0.3.4
|
||||
openai~=1.108.2
|
||||
openai~=1.108.2
|
||||
google-generativeai~=0.8.5
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
# 工具命名和功能混淆分析报告
|
||||
|
||||
## 一、潜在混淆点分析
|
||||
|
||||
### 1. 媒体相关工具(高风险混淆)
|
||||
|
||||
#### 1.1 `search_media` vs `recognize_media` vs `scrape_metadata`
|
||||
**当前状态:**
|
||||
- `search_media`: 在TMDB数据库中搜索媒体(按标题、年份等条件搜索)
|
||||
- `recognize_media`: 从种子标题或文件路径识别媒体信息
|
||||
- `scrape_metadata`: 为文件或目录刮削元数据(生成NFO文件等)
|
||||
|
||||
**混淆风险:** ⚠️ **中等**
|
||||
- 三个工具都涉及"媒体",但功能完全不同
|
||||
- `search_media` 是数据库搜索,`recognize_media` 是信息识别,`scrape_metadata` 是文件操作
|
||||
- 描述已区分,但命名可能让智能体混淆
|
||||
|
||||
**建议:**
|
||||
- ✅ 描述已清晰区分,无需修改
|
||||
- 可考虑在描述中更明确强调差异:
|
||||
- `search_media`: "Search for media in TMDB database by title/year/type"
|
||||
- `recognize_media`: "Extract/identify media info from torrent titles or file paths"
|
||||
- `scrape_metadata`: "Generate metadata files (NFO, posters) for existing media files"
|
||||
|
||||
### 2. 目录相关工具(高风险混淆)
|
||||
|
||||
#### 2.1 `query_directories` vs `list_directory`
|
||||
**当前状态:**
|
||||
- `query_directories`: 查询系统目录配置(配置信息,如下载目录、媒体库目录的设置)
|
||||
- `list_directory`: 列出文件系统目录内容(实际文件和文件夹)
|
||||
|
||||
**混淆风险:** ⚠️⚠️ **高**
|
||||
- 两个工具都涉及"目录",但一个是配置查询,一个是文件列表
|
||||
- 命名相似,容易混淆
|
||||
|
||||
**建议:**
|
||||
- 考虑重命名 `query_directories` 为 `query_directory_config` 或 `query_directory_settings`
|
||||
- 或者在描述中更明确:
|
||||
- `query_directories`: "Query directory **configuration settings** (download/library paths, storage types, etc.)"
|
||||
- `list_directory`: "List **actual files and folders** in a file system directory"
|
||||
|
||||
### 3. 媒体库查询工具(低风险)
|
||||
|
||||
#### 3.1 `query_media_library` vs `query_media_latest`
|
||||
**当前状态:**
|
||||
- `query_media_library`: 检查特定媒体是否存在于媒体库中
|
||||
- `query_media_latest`: 查询媒体服务器最近入库的影片列表
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- 命名清晰,功能区分明确
|
||||
- 描述已明确说明差异
|
||||
|
||||
### 4. 订阅相关工具(低风险)
|
||||
|
||||
#### 4.1 `search_subscribe` vs `query_subscribes`
|
||||
**当前状态:**
|
||||
- `search_subscribe`: 搜索订阅的缺失剧集(执行搜索和下载操作)
|
||||
- `query_subscribes`: 查询订阅状态和列表(只查询,不执行操作)
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- `search` vs `query` 已明确区分操作类型
|
||||
- 描述清晰
|
||||
|
||||
#### 4.2 多个订阅查询工具
|
||||
- `query_subscribes`: 查询所有订阅
|
||||
- `query_subscribe_history`: 查询订阅历史
|
||||
- `query_subscribe_shares`: 查询共享订阅
|
||||
- `query_popular_subscribes`: 查询热门订阅
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- 命名清晰,通过后缀区分功能
|
||||
|
||||
### 5. 下载相关工具(低风险)
|
||||
|
||||
#### 5.1 `query_downloads` vs `query_downloaders`
|
||||
**当前状态:**
|
||||
- `query_downloads`: 查询下载任务状态
|
||||
- `query_downloaders`: 查询下载器配置
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- 命名清晰,单复数已区分
|
||||
|
||||
### 6. 调度器/工作流工具(低风险)
|
||||
|
||||
#### 6.1 `query_schedulers` vs `run_scheduler`
|
||||
**当前状态:**
|
||||
- `query_schedulers`: 查询调度器任务列表
|
||||
- `run_scheduler`: 运行调度器任务
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- `query` vs `run` 已明确区分
|
||||
|
||||
#### 6.2 `query_workflows` vs `run_workflow`
|
||||
**当前状态:**
|
||||
- `query_workflows`: 查询工作流列表
|
||||
- `run_workflow`: 运行工作流
|
||||
|
||||
**混淆风险:** ✅ **低**
|
||||
- `query` vs `run` 已明确区分
|
||||
|
||||
## 二、命名规范分析
|
||||
|
||||
### 命名模式总结:
|
||||
1. **Search类** (`search_*`): 执行搜索操作
|
||||
- `search_media`, `search_person`, `search_torrents`, `search_subscribe`, `search_web`
|
||||
|
||||
2. **Query类** (`query_*`): 查询信息(只读)
|
||||
- `query_*`: 各种查询工具
|
||||
|
||||
3. **Add/Update/Delete类**: 执行操作
|
||||
- `add_*`, `update_*`, `delete_*`
|
||||
|
||||
4. **Run类**: 执行任务
|
||||
- `run_scheduler`, `run_workflow`
|
||||
|
||||
### 命名一致性:✅ 良好
|
||||
- 命名模式统一,易于理解
|
||||
- 动词选择恰当(search/query/add/update/delete/run)
|
||||
|
||||
## 三、建议优化
|
||||
|
||||
### 高优先级(建议修改)
|
||||
|
||||
1. **`query_directories` 重命名或增强描述**
|
||||
- 建议:在描述中明确强调是"配置查询"而非"文件列表"
|
||||
- 或考虑重命名为 `query_directory_config`
|
||||
|
||||
### 中优先级(可选优化)
|
||||
|
||||
1. **`search_media`、`recognize_media`、`scrape_metadata` 的描述增强**
|
||||
- 在描述开头明确说明数据源/操作类型
|
||||
- 例如:`search_media`: "Search TMDB database for media..."
|
||||
- 例如:`recognize_media`: "Extract media info from torrent titles or file paths..."
|
||||
- 例如:`scrape_metadata`: "Generate metadata files for existing media files..."
|
||||
|
||||
### 低优先级(当前已足够清晰)
|
||||
|
||||
- 其他工具命名和描述已足够清晰,无需修改
|
||||
|
||||
## 四、总结
|
||||
|
||||
**总体评价:** ✅ **良好**
|
||||
|
||||
大部分工具的命名和描述已经足够清晰,只有少数几个地方存在潜在的混淆风险:
|
||||
|
||||
1. **`query_directories` vs `list_directory`** - 需要更明确的区分
|
||||
2. **媒体相关三个工具** - 描述可以更明确强调差异
|
||||
|
||||
建议优先处理 `query_directories` 的描述或命名,其他工具当前状态已足够清晰。
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.8.7'
|
||||
FRONTEND_VERSION = 'v2.8.7'
|
||||
APP_VERSION = 'v2.8.9'
|
||||
FRONTEND_VERSION = 'v2.8.9'
|
||||
|
||||
Reference in New Issue
Block a user