mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-08 21:02:44 +08:00
Compare commits
445 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26e41e1c14 | ||
|
|
7bdb629f03 | ||
|
|
fd92f986da | ||
|
|
69a1207102 | ||
|
|
def652c768 | ||
|
|
c35faf5356 | ||
|
|
0615a33206 | ||
|
|
e77530bdc5 | ||
|
|
8c62df63cc | ||
|
|
bd36eade77 | ||
|
|
d2c023081a | ||
|
|
63d0850b38 | ||
|
|
c86659428f | ||
|
|
bf7cc6caf0 | ||
|
|
26b8be6041 | ||
|
|
f978f9196f | ||
|
|
75cb8d2a3c | ||
|
|
17a21ed707 | ||
|
|
f390647139 | ||
|
|
aacd91e196 | ||
|
|
258171c9c4 | ||
|
|
812c5873aa | ||
|
|
4c3d47f1f0 | ||
|
|
ba7b6ba869 | ||
|
|
d0471ae512 | ||
|
|
636c4be9fb | ||
|
|
6bec765a9d | ||
|
|
d61d16ccc4 | ||
|
|
f2a5715b24 | ||
|
|
c064c3781f | ||
|
|
bb4dffe2a4 | ||
|
|
37cf3eeef3 | ||
|
|
40395b2999 | ||
|
|
32afe6445f | ||
|
|
793a991913 | ||
|
|
d278224ff1 | ||
|
|
9b4d0ce6a8 | ||
|
|
a1829fe590 | ||
|
|
2b2b39365c | ||
|
|
1147930f3f | ||
|
|
636f338ed7 | ||
|
|
72365d00b4 | ||
|
|
19d8086732 | ||
|
|
30488418e5 | ||
|
|
2f0badd74a | ||
|
|
6045b0579b | ||
|
|
498f1fec74 | ||
|
|
f6a541f2b9 | ||
|
|
8ce78eabca | ||
|
|
2c34c5309f | ||
|
|
77e680168a | ||
|
|
8a7e59742f | ||
|
|
42bac14770 | ||
|
|
8323834483 | ||
|
|
1751caef62 | ||
|
|
d622d1474d | ||
|
|
f28be2e7de | ||
|
|
17773913ae | ||
|
|
d469c2d3f9 | ||
|
|
4e74d32882 | ||
|
|
7b8cd37a9b | ||
|
|
eda306d726 | ||
|
|
94f3b1fe84 | ||
|
|
c50e3ba293 | ||
|
|
eff7818912 | ||
|
|
270bcff8f3 | ||
|
|
e04963c2dc | ||
|
|
f369967c91 | ||
|
|
cd982c5526 | ||
|
|
16e03c9d37 | ||
|
|
d38b1f5364 | ||
|
|
f57ba4d05e | ||
|
|
172eeaafcf | ||
|
|
3115ed28b2 | ||
|
|
d8dc53805c | ||
|
|
7218d10e1b | ||
|
|
89bf85f501 | ||
|
|
8334a468d0 | ||
|
|
3da80ed077 | ||
|
|
2883ccbe87 | ||
|
|
5d3443fee4 | ||
|
|
27756a53db | ||
|
|
71cde6661d | ||
|
|
a857337b31 | ||
|
|
4ee21ffae4 | ||
|
|
d8399f7e85 | ||
|
|
574ac8d32f | ||
|
|
a2611bfa7d | ||
|
|
853badb76f | ||
|
|
5d69e1d2a5 | ||
|
|
6494f28bdb | ||
|
|
f55916bda2 | ||
|
|
04691ee197 | ||
|
|
2ac0e564e1 | ||
|
|
6072a29a20 | ||
|
|
8658942385 | ||
|
|
cc4859950c | ||
|
|
23b81ad6f1 | ||
|
|
e3b9dca5c0 | ||
|
|
a2359a1ad2 | ||
|
|
cb875b1b34 | ||
|
|
b92a85b4bc | ||
|
|
8c7dd6bab2 | ||
|
|
aad7df64d7 | ||
|
|
8474342007 | ||
|
|
61ccb4be65 | ||
|
|
1c6f69707c | ||
|
|
e08e8c482a | ||
|
|
548c1d2cab | ||
|
|
5a071bf3d1 | ||
|
|
1bffcbd947 | ||
|
|
274a36a83a | ||
|
|
ec40f36114 | ||
|
|
af19f274a7 | ||
|
|
2316004194 | ||
|
|
98762198ef | ||
|
|
1469de22a4 | ||
|
|
1e687f960a | ||
|
|
7f01b835fd | ||
|
|
e46b6c5c01 | ||
|
|
74226ad8df | ||
|
|
f8ae7be539 | ||
|
|
37b16e380d | ||
|
|
9ea3e9f652 | ||
|
|
54422b5181 | ||
|
|
712995dcf3 | ||
|
|
c2767b0fd6 | ||
|
|
179cc61f65 | ||
|
|
f3b910d55a | ||
|
|
f4157b52ea | ||
|
|
79710310ce | ||
|
|
3412498438 | ||
|
|
b896b07a08 | ||
|
|
379bff0622 | ||
|
|
474f47aa9f | ||
|
|
f1e26a4133 | ||
|
|
e37f881207 | ||
|
|
306c0b707b | ||
|
|
08c448ee30 | ||
|
|
1532014067 | ||
|
|
fa9f604af9 | ||
|
|
3b3d0d6539 | ||
|
|
9641d33040 | ||
|
|
eca339d107 | ||
|
|
ca18705d88 | ||
|
|
8f17b52466 | ||
|
|
8cf84e722b | ||
|
|
7c4d736b54 | ||
|
|
1b3ae6ab25 | ||
|
|
a4ad08136e | ||
|
|
df5e7997c5 | ||
|
|
b2cb3768c1 | ||
|
|
fa169c5cd3 | ||
|
|
bbb3975b67 | ||
|
|
4502a9c4fa | ||
|
|
86905a2670 | ||
|
|
b1e60a4867 | ||
|
|
1efe3324fb | ||
|
|
55c1e37d39 | ||
|
|
7fa700317c | ||
|
|
bbe831a57c | ||
|
|
90c86c056c | ||
|
|
36f22a28df | ||
|
|
ac03c51e2c | ||
|
|
bd9e92f705 | ||
|
|
281eff5eb2 | ||
|
|
abbd2253ad | ||
|
|
46466624ae | ||
|
|
0ba8d51b2a | ||
|
|
a1408ee18f | ||
|
|
58030bbcff | ||
|
|
e1b3e6ef01 | ||
|
|
298a6ba8ab | ||
|
|
e5bf47629f | ||
|
|
ea29ee9f66 | ||
|
|
868c2254de | ||
|
|
567522c87a | ||
|
|
25fd47f57b | ||
|
|
f89d6342d1 | ||
|
|
b02affdea3 | ||
|
|
6e5ade943b | ||
|
|
a6ed0c0d00 | ||
|
|
68402aadd7 | ||
|
|
85cacd447b | ||
|
|
11262b321a | ||
|
|
bf290f063d | ||
|
|
7ac0fbaf76 | ||
|
|
7489c76722 | ||
|
|
bcdf1b6efe | ||
|
|
8a9dbe212c | ||
|
|
16bd71a6cb | ||
|
|
71caad0655 | ||
|
|
2c62ffe34a | ||
|
|
3450a89880 | ||
|
|
a081a69bbe | ||
|
|
271d1d23d5 | ||
|
|
605aba1a3c | ||
|
|
be3c2b4c7c | ||
|
|
08eb32d7bd | ||
|
|
2b9cda15e4 | ||
|
|
f6055b290a | ||
|
|
ec665e05e4 | ||
|
|
2b6d7205ec | ||
|
|
41381a920c | ||
|
|
f1b3fc2254 | ||
|
|
a677ed307d | ||
|
|
0ab23ee972 | ||
|
|
43f56d39be | ||
|
|
a39caee5f5 | ||
|
|
2edfdf47c8 | ||
|
|
3819461db5 | ||
|
|
85654dd7dd | ||
|
|
619a70416b | ||
|
|
16d996fe70 | ||
|
|
1baeb6da19 | ||
|
|
1641d432dd | ||
|
|
1bf9862e47 | ||
|
|
602a394043 | ||
|
|
22a2415ca5 | ||
|
|
feb034352d | ||
|
|
a7c8942c78 | ||
|
|
95f2ac3811 | ||
|
|
91354295f2 | ||
|
|
c9c4ab5911 | ||
|
|
a26c5e40dd | ||
|
|
80f5c7bc44 | ||
|
|
4833b39c52 | ||
|
|
f478958943 | ||
|
|
0469ad46d6 | ||
|
|
5fe5deb9df | ||
|
|
ce83bc24bd | ||
|
|
dce729c8cb | ||
|
|
a9d17cd96f | ||
|
|
294bb3d4a1 | ||
|
|
b31b9261f2 | ||
|
|
2211f8d9e4 | ||
|
|
b9b7b00a7f | ||
|
|
843faf6103 | ||
|
|
4af5dad9a8 | ||
|
|
52437c9d18 | ||
|
|
c6cb4c8479 | ||
|
|
c3714ec251 | ||
|
|
dbe2f94af1 | ||
|
|
07fd5f8a9e | ||
|
|
9e64b4cd7f | ||
|
|
f08a7b9eb3 | ||
|
|
a6fa764e2a | ||
|
|
01676668f1 | ||
|
|
8e5e4f460d | ||
|
|
f907b8a84d | ||
|
|
a3a4285f90 | ||
|
|
0979163b79 | ||
|
|
248a25eaee | ||
|
|
f95b1fa68a | ||
|
|
d2b5d69051 | ||
|
|
3ca419b735 | ||
|
|
50e275a2f9 | ||
|
|
aeccf78957 | ||
|
|
cb3cef70e5 | ||
|
|
b9bd303bf8 | ||
|
|
57d4786a7f | ||
|
|
df031455b2 | ||
|
|
30059eff4f | ||
|
|
bc289b48c8 | ||
|
|
067d8b99b8 | ||
|
|
00a6a9c42d | ||
|
|
070425d446 | ||
|
|
7405883444 | ||
|
|
66959937ed | ||
|
|
e431efbcba | ||
|
|
ba00baa5a0 | ||
|
|
0fb5d4a164 | ||
|
|
1ac717b67f | ||
|
|
273cbd447e | ||
|
|
cee41567a2 | ||
|
|
1aae5eb1a6 | ||
|
|
28a4c81aff | ||
|
|
5e077cd64d | ||
|
|
e3f957a59b | ||
|
|
55c62a3ab5 | ||
|
|
22e7eef1bd | ||
|
|
d6524907f3 | ||
|
|
357db334cd | ||
|
|
f8bed3909b | ||
|
|
182bbdde91 | ||
|
|
2c70f990c2 | ||
|
|
0b01a6aa91 | ||
|
|
e557dffbc6 | ||
|
|
7f33b0b1b8 | ||
|
|
41ddf77a5b | ||
|
|
8c657ce41d | ||
|
|
3ff3b9ed4a | ||
|
|
ef43419ecd | ||
|
|
2ca375c214 | ||
|
|
cbd45c1d0f | ||
|
|
2592ea3464 | ||
|
|
73ac97cd96 | ||
|
|
e014663e97 | ||
|
|
58592e961f | ||
|
|
9a99b9ce82 | ||
|
|
8c6dca1751 | ||
|
|
cf488d5f5f | ||
|
|
515584d34c | ||
|
|
fb2becc7f2 | ||
|
|
0f8ceb0fac | ||
|
|
a70bf18770 | ||
|
|
2de83c44ab | ||
|
|
7b99f09810 | ||
|
|
6b4ba8bfad | ||
|
|
0c6cfc5020 | ||
|
|
abd9733e7f | ||
|
|
98c3ae5e76 | ||
|
|
bb5a657469 | ||
|
|
7797532350 | ||
|
|
c3a5106adc | ||
|
|
c5fd935dd0 | ||
|
|
ec375a19ae | ||
|
|
51e940617c | ||
|
|
58ec8bd437 | ||
|
|
a096395086 | ||
|
|
4bd08bd915 | ||
|
|
2c849cfa7a | ||
|
|
501d530d1d | ||
|
|
91fc4327f4 | ||
|
|
8d56c67079 | ||
|
|
e52d43458e | ||
|
|
9b125bf9b0 | ||
|
|
0716c65269 | ||
|
|
ba3ce4f1b5 | ||
|
|
07f72b0cdc | ||
|
|
bda19df87f | ||
|
|
5d82fae2b0 | ||
|
|
0813b87221 | ||
|
|
961ecfc720 | ||
|
|
81f30ef25a | ||
|
|
140b0d3df2 | ||
|
|
b3d69d7de4 | ||
|
|
8e65564fb8 | ||
|
|
06ce9bd4de | ||
|
|
274fc2d74f | ||
|
|
2f1a448afe | ||
|
|
99cab7c337 | ||
|
|
81f7548579 | ||
|
|
6ebd50bebc | ||
|
|
378ba51f4d | ||
|
|
63a890e85d | ||
|
|
bf4f9921e2 | ||
|
|
167ae65695 | ||
|
|
2affa7c9b8 | ||
|
|
785540e178 | ||
|
|
bcad4c0bc6 | ||
|
|
5af217fbf5 | ||
|
|
128aa2ef23 | ||
|
|
fce1186dd1 | ||
|
|
9a7b11f804 | ||
|
|
b068a06fa8 | ||
|
|
931a42e981 | ||
|
|
e0a20a6697 | ||
|
|
1ef4374899 | ||
|
|
3b7212740b | ||
|
|
4b80b8dc1f | ||
|
|
b7f24827e6 | ||
|
|
1c08a22881 | ||
|
|
8bd848519d | ||
|
|
e19f2aa76d | ||
|
|
4a99e2896f | ||
|
|
de3c83b0aa | ||
|
|
36bdb831be | ||
|
|
1809690915 | ||
|
|
e51b679380 | ||
|
|
10c26de7cb | ||
|
|
ca5ec8af0f | ||
|
|
d1d7b8ce55 | ||
|
|
77f8983307 | ||
|
|
ba415acd37 | ||
|
|
bcf13099ac | ||
|
|
eb2b34d71c | ||
|
|
d0b665f773 | ||
|
|
a1674b1ae5 | ||
|
|
af83681f6a | ||
|
|
bebacf7b20 | ||
|
|
6dc1fcbc3e | ||
|
|
b599ef4509 | ||
|
|
526b6a1119 | ||
|
|
88173db4ce | ||
|
|
e139b1ab22 | ||
|
|
6c1e0058c1 | ||
|
|
c96633eb83 | ||
|
|
91eb35a77b | ||
|
|
d749d59cad | ||
|
|
80396b4d30 | ||
|
|
64b93a009c | ||
|
|
2b32250504 | ||
|
|
9b5f863832 | ||
|
|
fd422d7446 | ||
|
|
5162b2748e | ||
|
|
56c684ec06 | ||
|
|
7e93b33407 | ||
|
|
7662235802 | ||
|
|
e41f9facc7 | ||
|
|
785b8ede11 | ||
|
|
78b198ad70 | ||
|
|
c2c0515991 | ||
|
|
b97fefdb8d | ||
|
|
840da6dd85 | ||
|
|
972d916126 | ||
|
|
e3ed065f5f | ||
|
|
760ebe6113 | ||
|
|
a329d3ad89 | ||
|
|
01f8561582 | ||
|
|
883ea5c996 | ||
|
|
99cf13ed9b | ||
|
|
91c7ef6801 | ||
|
|
84ef5705e7 | ||
|
|
cf2a0cf8c2 | ||
|
|
48c25c40e4 | ||
|
|
996d8ab954 | ||
|
|
fac2546a92 | ||
|
|
728ea6172a | ||
|
|
f59d225029 | ||
|
|
0b178a715f | ||
|
|
e06e5328c2 | ||
|
|
1c14cd0979 | ||
|
|
f9141f5ba2 | ||
|
|
48da5c976c | ||
|
|
fa38c81c08 | ||
|
|
8d5fe5270f | ||
|
|
0dc0d66549 | ||
|
|
f589fcc2d0 | ||
|
|
edd44a0993 | ||
|
|
2aae496742 | ||
|
|
6f72046f86 | ||
|
|
d4a9b446a6 | ||
|
|
95f571e9b9 | ||
|
|
e8aeae5c07 | ||
|
|
ddf6dc0343 | ||
|
|
36d55a9db7 | ||
|
|
7d41379ad5 | ||
|
|
63e928da96 | ||
|
|
5c983b64bc | ||
|
|
b2d36c0e68 | ||
|
|
6123a1620e | ||
|
|
5ae7c10a00 | ||
|
|
b5a6794381 | ||
|
|
6b575f836a |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,4 +27,7 @@ venv
|
||||
|
||||
# Pylint
|
||||
pylint-report.json
|
||||
.pylint.d/
|
||||
.pylint.d/
|
||||
|
||||
# AI
|
||||
.claude/
|
||||
|
||||
@@ -30,6 +30,8 @@
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
from typing import Dict, List, Any, Union
|
||||
import json
|
||||
import tiktoken
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
|
||||
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -26,7 +30,9 @@ class AgentChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
@@ -39,12 +45,6 @@ class MoviePilotAgent:
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
@@ -56,9 +56,6 @@ class MoviePilotAgent:
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
@@ -66,48 +63,15 @@ class MoviePilotAgent:
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
"""
|
||||
初始化LLM模型
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
"""
|
||||
初始化工具列表
|
||||
"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
@@ -119,43 +83,56 @@ class MoviePilotAgent:
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
"""
|
||||
初始化提示词模板
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
@@ -169,13 +146,140 @@ class MoviePilotAgent:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
@staticmethod
|
||||
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
|
||||
"""
|
||||
通用的Token计数器
|
||||
"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
# 基础开销 (每个消息大约 3 个 token)
|
||||
num_tokens += 3
|
||||
|
||||
# 1. 处理文本内容 (content)
|
||||
if isinstance(message.content, str):
|
||||
num_tokens += len(encoding.encode(message.content))
|
||||
elif isinstance(message.content, list):
|
||||
for part in message.content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
num_tokens += len(encoding.encode(part.get("text", "")))
|
||||
|
||||
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
|
||||
if getattr(message, "tool_calls", None):
|
||||
for tool_call in message.tool_calls:
|
||||
# 函数名
|
||||
num_tokens += len(encoding.encode(tool_call.get("name", "")))
|
||||
# 参数 (转为 JSON 估算)
|
||||
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
|
||||
num_tokens += len(encoding.encode(args_str))
|
||||
# 额外的结构开销 (ID 等)
|
||||
num_tokens += 3
|
||||
|
||||
# 3. 处理角色权重
|
||||
num_tokens += 1
|
||||
|
||||
# 加上回复的起始 Token (大约 3 个 token)
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Token计数失败: {e}")
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
# 消息裁剪器,防止上下文超出限制
|
||||
base_trimmer = trim_messages(
|
||||
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
|
||||
strategy="last",
|
||||
token_counter=self._token_counter,
|
||||
include_system=True,
|
||||
allow_partial=False,
|
||||
start_on="human",
|
||||
)
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
|
||||
# 二次校验:确保不出现 broken tool chains
|
||||
# 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
|
||||
# 2. ToolMessage 必须有对应的 AIMessage 前置
|
||||
safe_messages = []
|
||||
i = 0
|
||||
while i < len(trimmed):
|
||||
msg = trimmed[i]
|
||||
|
||||
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
||||
# 检查工具调用序列是否完整
|
||||
tool_calls = msg.tool_calls
|
||||
is_valid_sequence = True
|
||||
tool_results = []
|
||||
|
||||
# 向后查找对应的 ToolMessage
|
||||
temp_i = i + 1
|
||||
for tool_call in tool_calls:
|
||||
if temp_i >= len(trimmed):
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
next_msg = trimmed[temp_i]
|
||||
if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
|
||||
tool_results.append(next_msg)
|
||||
temp_i += 1
|
||||
else:
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
if is_valid_sequence:
|
||||
# 序列完整,保留消息
|
||||
safe_messages.append(msg)
|
||||
safe_messages.extend(tool_results)
|
||||
i = temp_i # 跳过已处理的工具结果
|
||||
else:
|
||||
# 序列不完整,丢弃该 AIMessage(后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
|
||||
logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if isinstance(msg, ToolMessage):
|
||||
# 如果在这里遇到 ToolMessage,说明它没有被上面的逻辑消费,则是孤立的(或者顺序错乱)
|
||||
logger.warning("移除孤立的 ToolMessage")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他类型的消息直接保留
|
||||
safe_messages.append(msg)
|
||||
i += 1
|
||||
|
||||
if len(safe_messages) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
|
||||
return safe_messages
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_openai_tool_messages(
|
||||
x["intermediate_steps"]
|
||||
)
|
||||
)
|
||||
| self.prompt
|
||||
| RunnableLambda(validated_trimmer)
|
||||
| self.llm.bind_tools(self.tools)
|
||||
| OpenAIToolsAgentOutputParser()
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
@@ -196,11 +300,83 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
async def _summarize_history(self):
|
||||
"""
|
||||
总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
|
||||
"""
|
||||
try:
|
||||
# 获取当前历史记录
|
||||
chat_history = self.get_session_history(self.session_id)
|
||||
messages = chat_history.messages
|
||||
if not messages:
|
||||
return
|
||||
|
||||
logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
|
||||
|
||||
# 将消息转换为摘要所需的文本格式
|
||||
history_text = ""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_text += f"用户: {msg.content}\n"
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_text += f"智能体: {msg.content}\n"
|
||||
if getattr(msg, "tool_calls", None):
|
||||
for tool_call in msg.tool_calls:
|
||||
history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
|
||||
elif isinstance(msg, ToolMessage):
|
||||
history_text += f"工具响应: {msg.content}\n"
|
||||
elif isinstance(msg, SystemMessage):
|
||||
history_text += f"系统: {msg.content}\n"
|
||||
|
||||
# 摘要提示词
|
||||
summary_prompt = (
|
||||
"Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
|
||||
"Your goal is to condense the history while retaining all critical details for future reference. "
|
||||
"Ensure you include:\n"
|
||||
"1. User's core intents, specific requests, and any mentioned preferences.\n"
|
||||
"2. Names of movies, TV shows, or other key entities discussed.\n"
|
||||
"3. A concise log of tool calls made and their specific results/outcomes.\n"
|
||||
"4. The current status of any tasks and any pending actions.\n"
|
||||
"5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
|
||||
"The summary should be dense with information and serve as the primary context for the next stage of the interaction."
|
||||
)
|
||||
|
||||
# 调用 LLM 进行总结 (非流式)
|
||||
summary_llm = LLMHelper.get_llm(streaming=False)
|
||||
response = await summary_llm.ainvoke([
|
||||
SystemMessage(content=summary_prompt),
|
||||
HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
|
||||
])
|
||||
summary_content = str(response.content)
|
||||
|
||||
if not summary_content:
|
||||
logger.warning("总结生成失败,跳过重置逻辑。")
|
||||
return
|
||||
|
||||
# 清空原有的会话记录并插入新的系统总结
|
||||
await conversation_manager.clear_memory(self.session_id, self.user_id)
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="system",
|
||||
content=f"<history_summary>\n{summary_content}\n</history_summary>"
|
||||
)
|
||||
logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"执行会话总结出错: {str(e)}")
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 检查上下文长度是否超过 90%
|
||||
history = self.get_session_history(self.session_id)
|
||||
if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
|
||||
await self._summarize_history()
|
||||
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
@@ -209,27 +385,33 @@ class MoviePilotAgent:
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
result = await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
await self.send_agent_message(agent_message)
|
||||
if agent_message:
|
||||
# 发送回复
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
# 添加Agent回复到记忆
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
|
||||
@@ -241,7 +423,9 @@ class MoviePilotAgent:
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
"""
|
||||
执行LangChain Agent
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
@@ -268,13 +452,15 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"output": str(e),
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
@@ -287,26 +473,32 @@ class MoviePilotAgent:
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
"""
|
||||
清理智能体资源
|
||||
"""
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
"""
|
||||
AI智能体管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
@@ -314,7 +506,9 @@ class AgentManager:
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
@@ -325,7 +519,6 @@ class AgentManager:
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -342,12 +535,14 @@ class AgentManager:
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
"""
|
||||
流式输出回调处理器
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
"""
|
||||
处理新的token
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
|
||||
@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
"""
|
||||
关闭记忆管理器
|
||||
"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
@@ -45,47 +51,84 @@ class ConversationMemoryManager:
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
@staticmethod
|
||||
def _get_memory_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算内存Key
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
# 尝试从Redis加载
|
||||
@staticmethod
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
@@ -138,7 +181,7 @@ class ConversationMemoryManager:
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
@@ -146,8 +189,10 @@ class ConversationMemoryManager:
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
@@ -167,7 +212,7 @@ class ConversationMemoryManager:
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
@@ -176,19 +221,18 @@ class ConversationMemoryManager:
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
return memory.messages[:-1]
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
@@ -197,8 +241,10 @@ class ConversationMemoryManager:
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取最近的消息
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
@@ -207,36 +253,41 @@ class ConversationMemoryManager:
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
保存记忆到内存
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
@@ -247,8 +298,22 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
"""
|
||||
清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
@@ -278,3 +343,5 @@ class ConversationMemoryManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
@@ -1,70 +1,72 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
You are an AI media assistant powered by MoviePilot, specialized in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
<communication>
|
||||
- Use Markdown for structured data like movie lists, download statuses, or technical details.
|
||||
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
|
||||
- ALWAYS use backticks for media titles (e.g., `Interstellar`), file paths, or specific parameters.
|
||||
- Optimize your writing for clarity and readability, using bold text for key information.
|
||||
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
|
||||
## Working Principles
|
||||
Important Notes:
|
||||
- User-Centric: Your tone should be helpful, professional, and media-savvy.
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
</communication>
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
<status_update_spec>
|
||||
Definition: Provide a brief progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "I'll search for the movie"), execute the corresponding tool call in the same turn.
|
||||
- Use natural tenses: "I've found...", "I'm checking...", "I will now add...".
|
||||
- Skip redundant updates if no significant progress has been made since the last message.
|
||||
</status_update_spec>
|
||||
|
||||
## Common Operation Workflows
|
||||
<summary_spec>
|
||||
At the end of your session/turn, provide a concise summary of your actions.
|
||||
- Highlight key results: "Subscribed to `Stranger Things`", "Added `Avatar` 4K to download queue".
|
||||
- Use bullet points for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the outcome for the user.
|
||||
</summary_spec>
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
<flow>
|
||||
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. Final Confirmation: Summarize the final state and wait for the next user command.
|
||||
</flow>
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
<tool_calling_strategy>
|
||||
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
<media_management_rules>
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
</media_management_rules>
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
@@ -1,13 +1,15 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
"""
|
||||
提示词管理器
|
||||
"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
@@ -17,22 +19,20 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
"""
|
||||
加载指定的提示词
|
||||
"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
@@ -46,73 +46,43 @@ class PromptManager:
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
|
||||
# 识别渠道
|
||||
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
base_prompt = base_prompt.replace(
|
||||
"{markdown_spec}",
|
||||
self._generate_formatting_instructions(caps)
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
@@ -15,7 +17,9 @@ class ToolChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
@@ -33,23 +37,79 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
"""
|
||||
异步运行工具
|
||||
"""
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
|
||||
# 生成唯一的工具调用ID
|
||||
call_id = f"call_{str(uuid.uuid4())[:16]}"
|
||||
|
||||
# 记忆工具调用
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
return await self.run(**kwargs)
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
|
||||
# 执行工具,捕获异常确保结果总是被存储到记忆中
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
except Exception as e:
|
||||
# 记录异常详情
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
@@ -71,17 +131,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
"""
|
||||
设置消息属性
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||
from app.agent.tools.impl.search_subscribe import SearchSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_download_tasks import QueryDownloadTasksTool
|
||||
from app.agent.tools.impl.query_library_exists import QueryLibraryExistsTool
|
||||
from app.agent.tools.impl.query_library_latest import QueryLibraryLatestTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.update_site import UpdateSiteTool
|
||||
from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_person import SearchPersonTool
|
||||
from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
|
||||
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
|
||||
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
@@ -30,39 +35,50 @@ from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.query_directories import QueryDirectoriesTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.agent.tools.impl.execute_command import ExecuteCommandTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
"""
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
SearchPersonTool,
|
||||
SearchPersonCreditsTool,
|
||||
RecognizeMediaTool,
|
||||
ScrapeMetadataTool,
|
||||
QueryEpisodeScheduleTool,
|
||||
QueryMediaDetailTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
SearchWebTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryRuleGroupsTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
@@ -71,8 +87,9 @@ class MoviePilotToolFactory:
|
||||
TestSiteTool,
|
||||
UpdateSiteCookieTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
QueryDirectoriesTool,
|
||||
QueryLibraryExistsTool,
|
||||
QueryLibraryLatestTool,
|
||||
QueryDirectorySettingsTool,
|
||||
ListDirectoryTool,
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
@@ -80,7 +97,8 @@ class MoviePilotToolFactory:
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool
|
||||
RunWorkflowTool,
|
||||
ExecuteCommandTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
|
||||
@@ -25,7 +25,7 @@ class AddDownloadInput(BaseModel):
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
|
||||
description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -31,6 +31,10 @@ class AddSubscribeInput(BaseModel):
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
@@ -59,11 +63,13 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, **kwargs) -> str:
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, effect={effect}")
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
@@ -87,6 +93,10 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
if filter_groups:
|
||||
subscribe_kwargs['filter_groups'] = filter_groups
|
||||
if sites:
|
||||
subscribe_kwargs['sites'] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
@@ -98,6 +108,9 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
if message and "已存在" in message:
|
||||
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
|
||||
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
@@ -111,6 +124,10 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
params.append(f"分辨率过滤: {resolution}")
|
||||
if effect:
|
||||
params.append(f"特效过滤: {effect}")
|
||||
if filter_groups:
|
||||
params.append(f"规则组: {', '.join(filter_groups)}")
|
||||
if sites:
|
||||
params.append(f"站点: {', '.join(map(str, sites))}")
|
||||
if params:
|
||||
result_msg += f"\n配置参数: {', '.join(params)}"
|
||||
return result_msg
|
||||
|
||||
81
app/agent/tools/impl/execute_command.py
Normal file
81
app/agent/tools/impl/execute_command.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""执行Shell命令工具"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode('utf-8', errors='replace').strip()
|
||||
stderr_str = stderr.decode('utf-8', errors='replace').strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时处理
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return f"命令执行超时 (限制: {timeout}秒)"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return f"执行命令时发生错误: {str(e)}"
|
||||
@@ -24,7 +24,7 @@ class ListDirectoryInput(BaseModel):
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
description: str = "List contents of a file system directory. Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items."
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDirectoriesInput(BaseModel):
|
||||
class QueryDirectorySettingsInput(BaseModel):
|
||||
"""查询系统目录设置工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
directory_type: Optional[str] = Field("all",
|
||||
@@ -21,10 +21,10 @@ class QueryDirectoriesInput(BaseModel):
|
||||
description="Filter directories by name (partial match, optional)")
|
||||
|
||||
|
||||
class QueryDirectoriesTool(MoviePilotTool):
|
||||
name: str = "query_directories"
|
||||
description: str = "Query system directory configuration and list all configured directories. Shows download directories, media library directories, storage settings, transfer modes, and other directory-related configurations."
|
||||
args_schema: Type[BaseModel] = QueryDirectoriesInput
|
||||
class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
name: str = "query_directory_settings"
|
||||
description: str = "Query system directory configuration settings (NOT file listings). Returns configured directory paths, storage types, transfer modes, and other directory-related settings. Use 'list_directory' to list actual files and folders in a directory."
|
||||
args_schema: Type[BaseModel] = QueryDirectorySettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
@@ -1,7 +1,7 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,9 +9,11 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus
|
||||
|
||||
|
||||
class QueryDownloadsInput(BaseModel):
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
@@ -22,10 +24,32 @@ class QueryDownloadsInput(BaseModel):
|
||||
title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)")
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadsInput
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
@staticmethod
|
||||
def _get_all_torrents(download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
) or []
|
||||
all_torrents.extend(downloading_torrents)
|
||||
|
||||
# 查询已完成的任务(可转移状态)
|
||||
transfer_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.TRANSFER
|
||||
) or []
|
||||
all_torrents.extend(transfer_torrents)
|
||||
|
||||
return all_torrents
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
@@ -60,7 +84,7 @@ class QueryDownloadsTool(MoviePilotTool):
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
@@ -84,14 +108,25 @@ class QueryDownloadsTool(MoviePilotTool):
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
title_lower = title.lower()
|
||||
for torrent in all_torrents:
|
||||
# 检查标题或名称是否匹配
|
||||
if (title.lower() in (torrent.title or "").lower()) or \
|
||||
(title.lower() in (torrent.name or "").lower()):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
|
||||
# 检查标题或名称是否匹配(包括下载历史中的标题)
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (torrent.name or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
if title_lower in history.title.lower():
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
@@ -110,7 +145,7 @@ class QueryDownloadsTool(MoviePilotTool):
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
downloads = download_chain.downloading(name=downloader) or []
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
@@ -119,7 +154,7 @@ class QueryDownloadsTool(MoviePilotTool):
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
139
app/agent/tools/impl/query_library_exists.py
Normal file
139
app/agent/tools/impl/query_library_exists.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
media_chain = MediaServerChain()
|
||||
|
||||
# 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据)
|
||||
meta = MetaBase(title=title)
|
||||
if year:
|
||||
meta.year = str(year)
|
||||
if media_type == "电影":
|
||||
meta.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
meta.type = MediaType.TV
|
||||
|
||||
# 使用识别方法补充信息
|
||||
recognize_info = media_chain.recognize_media(meta=meta)
|
||||
if recognize_info:
|
||||
mediainfo = recognize_info
|
||||
else:
|
||||
# 识别失败,创建基本信息的 MediaInfo
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
|
||||
# 2. 调用媒体服务器接口实时查询存在信息
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
if not existsinfo:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 3. 如果找到了,获取详细信息并组装结果
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
|
||||
# 对于电视剧,补充已存在的季集详情及进度统计
|
||||
if existsinfo.type == MediaType.TV:
|
||||
# 注入已存在集信息 (Dict[int, list])
|
||||
item_dict["seasoninfo"] = existsinfo.seasons
|
||||
|
||||
# 统计库中已存在的季集总数
|
||||
if existsinfo.seasons:
|
||||
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
|
||||
# 如果识别到了元数据,补充总计对比和进度概览
|
||||
if mediainfo.seasons:
|
||||
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
# 进度概览,例如 "Season 1": "3/12"
|
||||
item_dict["seasons_progress"] = {
|
||||
f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集"
|
||||
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
|
||||
}
|
||||
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有获取到 iteminfo,返回基本信息
|
||||
result_dict = {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
}
|
||||
if existsinfo.type == MediaType.TV and existsinfo.seasons:
|
||||
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
if mediainfo.seasons:
|
||||
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
|
||||
86
app/agent/tools/impl/query_library_latest.py
Normal file
86
app/agent/tools/impl/query_library_latest.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""查询媒体服务器最近入库影片工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
|
||||
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
|
||||
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
server = kwargs.get("server")
|
||||
count = kwargs.get("count", 20)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
else:
|
||||
parts.append("所有服务器")
|
||||
|
||||
parts.append(f"数量: {count}条")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
|
||||
try:
|
||||
media_chain = MediaServerChain()
|
||||
results = []
|
||||
|
||||
# 如果没有指定服务器,获取所有启用的媒体服务器
|
||||
if not server:
|
||||
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
|
||||
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
|
||||
|
||||
if not enabled_servers:
|
||||
return "未找到启用的媒体服务器"
|
||||
|
||||
# 遍历所有启用的服务器
|
||||
for server_name in enabled_servers:
|
||||
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server_name
|
||||
results.append(item_dict)
|
||||
else:
|
||||
# 查询指定服务器
|
||||
latest_items = media_chain.latest(server=server, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server
|
||||
results.append(item_dict)
|
||||
|
||||
if not results:
|
||||
server_info = f"服务器 {server}" if server else "所有服务器"
|
||||
return f"未找到 {server_info} 的最近入库影片"
|
||||
|
||||
# 限制返回数量,避免结果过多
|
||||
if len(results) > count:
|
||||
results = results[:count]
|
||||
|
||||
return json.dumps(results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
|
||||
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"
|
||||
|
||||
120
app/agent/tools/impl/query_media_detail.py
Normal file
120
app/agent/tools/impl/query_media_detail.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""查询媒体详情工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)")
|
||||
media_type: str = Field(..., description="Media type: 'movie' or 'tv'")
|
||||
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影'→'movie', '电视剧'→'tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
|
||||
async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
|
||||
mtype = None
|
||||
if media_type:
|
||||
if media_type.lower() == 'movie':
|
||||
mtype = MediaType.MOVIE
|
||||
elif media_type.lower() == 'tv':
|
||||
mtype = MediaType.TV
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 精简 genres - 只保留名称
|
||||
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
|
||||
|
||||
# 精简 directors - 只保留姓名和职位
|
||||
directors = [
|
||||
{
|
||||
"name": d.get("name"),
|
||||
"job": d.get("job")
|
||||
}
|
||||
for d in (mediainfo.directors or [])
|
||||
if d.get("name")
|
||||
]
|
||||
|
||||
# 精简 actors - 只保留姓名和角色
|
||||
actors = [
|
||||
{
|
||||
"name": a.get("name"),
|
||||
"character": a.get("character")
|
||||
}
|
||||
for a in (mediainfo.actors or [])
|
||||
if a.get("name")
|
||||
]
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"overview": mediainfo.overview,
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
"actors": actors
|
||||
}
|
||||
|
||||
# 如果是电视剧,添加电视剧特有信息
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 精简 season_info - 只保留基础摘要
|
||||
season_info = [
|
||||
{
|
||||
"season_number": s.get("season_number"),
|
||||
"name": s.get("name"),
|
||||
"episode_count": s.get("episode_count"),
|
||||
"air_date": s.get("air_date")
|
||||
}
|
||||
for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
]
|
||||
|
||||
result.update({
|
||||
"number_of_seasons": mediainfo.number_of_seasons,
|
||||
"number_of_episodes": mediainfo.number_of_episodes,
|
||||
"first_air_date": mediainfo.first_air_date,
|
||||
"last_air_date": mediainfo.last_air_date,
|
||||
"season_info": season_info
|
||||
})
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询媒体详情失败: {str(e)}"
|
||||
logger.error(f"查询媒体详情失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id
|
||||
}, ensure_ascii=False)
|
||||
@@ -1,58 +0,0 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaServerItem
|
||||
|
||||
|
||||
class QueryMediaLibraryInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryMediaLibraryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_server_oper = MediaServerOper()
|
||||
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.to_dict() for m in filtered_medias])
|
||||
return "媒体库中未找到相关媒体"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
65
app/agent/tools/impl/query_rule_groups.py
Normal file
65
app/agent/tools/impl/query_rule_groups.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""查询规则组工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
return "正在查询所有规则组"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
rule_helper = RuleHelper()
|
||||
rule_groups = rule_helper.get_rule_groups()
|
||||
|
||||
if not rule_groups:
|
||||
return json.dumps({
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 精简字段,过滤掉 rule_string 避免结果过大
|
||||
simplified_groups = []
|
||||
for group in rule_groups:
|
||||
simplified = {
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category
|
||||
}
|
||||
simplified_groups.append(simplified)
|
||||
|
||||
result = {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
@@ -21,7 +21,7 @@ class QuerySitesInput(BaseModel):
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration."
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -14,7 +14,7 @@ class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
|
||||
@@ -33,7 +33,7 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "P": "已禁用"}
|
||||
status_map = {"R": "已启用", "S": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
|
||||
@@ -22,7 +22,7 @@ class RecognizeMediaInput(BaseModel):
|
||||
|
||||
class RecognizeMediaTool(MoviePilotTool):
|
||||
name: str = "recognize_media"
|
||||
description: str = "Recognize media information from torrent titles or file paths. Supports two modes: 1) Recognize from torrent title and optional subtitle, 2) Recognize from file path. Returns detailed media information including title, year, type, TMDB ID, overview, and other metadata."
|
||||
description: str = "Extract/identify media information from torrent titles or file paths (NOT database search). Supports two modes: 1) Extract from torrent title and optional subtitle, 2) Extract from file path. Returns detailed media information. Use 'search_media' to search TMDB database, or 'scrape_metadata' to generate metadata files for existing files."
|
||||
args_schema: Type[BaseModel] = RecognizeMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""刮削媒体元数据工具"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
@@ -9,6 +8,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
@@ -17,14 +17,17 @@ from app.schemas import FileItem
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(..., description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False, description="Whether to overwrite existing metadata files (default: False)")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Scrape media metadata (NFO files, posters, backgrounds, etc.) for a file or directory. Automatically recognizes media information from the file path and generates metadata files. Supports both local and remote storage."
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -32,19 +35,19 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
overwrite = kwargs.get("overwrite", False)
|
||||
|
||||
|
||||
message = f"正在刮削媒体元数据: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
if overwrite:
|
||||
message += " [覆盖模式]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
@@ -52,14 +55,14 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
@@ -68,23 +71,22 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
scrape_path = Path(path)
|
||||
meta = MetaInfoPath(scrape_path)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: media_chain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
@@ -93,7 +95,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
overwrite=overwrite
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
@@ -106,7 +108,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
@@ -115,4 +117,3 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class SearchMediaInput(BaseModel):
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
|
||||
description: str = "Search TMDB database for media resources (movies, TV shows, anime, etc.) by title, year, type, and other criteria. Returns detailed media information from TMDB. Use 'recognize_media' to extract info from torrent titles/file paths, or 'scrape_metadata' to generate metadata files."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -51,17 +51,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 构建搜索标题
|
||||
search_title = title
|
||||
if year:
|
||||
search_title = f"{title} {year}"
|
||||
if media_type:
|
||||
search_title = f"{search_title} {media_type}"
|
||||
if season:
|
||||
search_title = f"{search_title} S{season:02d}"
|
||||
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=search_title)
|
||||
meta, results = await media_chain.async_search(title=title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
@@ -72,7 +63,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
if season is not None and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
|
||||
83
app/agent/tools/impl/search_person.py
Normal file
83
app/agent/tools/impl/search_person.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""搜索人物工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonInput(BaseModel):
|
||||
"""搜索人物工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')")
|
||||
|
||||
|
||||
class SearchPersonTool(MoviePilotTool):
|
||||
name: str = "search_person"
|
||||
description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database."
|
||||
args_schema: Type[BaseModel] = SearchPersonInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
name = kwargs.get("name", "")
|
||||
return f"正在搜索人物: {name}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: name={name}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 使用 MediaChain.async_search_persons 方法搜索人物
|
||||
persons = await media_chain.async_search_persons(name=name)
|
||||
|
||||
if persons:
|
||||
# 限制最多30条结果
|
||||
total_count = len(persons)
|
||||
limited_persons = persons[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for person in limited_persons:
|
||||
simplified = {
|
||||
"name": person.name,
|
||||
"id": person.id,
|
||||
"source": person.source,
|
||||
"profile_path": person.profile_path,
|
||||
"original_name": person.original_name,
|
||||
"known_for_department": person.known_for_department,
|
||||
"popularity": person.popularity,
|
||||
"biography": person.biography[:200] + "..." if person.biography and len(person.biography) > 200 else person.biography,
|
||||
"birthday": person.birthday,
|
||||
"deathday": person.deathday,
|
||||
"place_of_birth": person.place_of_birth,
|
||||
"gender": person.gender,
|
||||
"imdb_id": person.imdb_id,
|
||||
"also_known_as": person.also_known_as[:5] if person.also_known_as else [], # 限制别名数量
|
||||
}
|
||||
# 添加豆瓣特有字段
|
||||
if person.source == "douban":
|
||||
simplified["url"] = person.url
|
||||
simplified["avatar"] = person.avatar
|
||||
simplified["latin_name"] = person.latin_name
|
||||
simplified["roles"] = person.roles[:5] if person.roles else [] # 限制角色数量
|
||||
# 添加Bangumi特有字段
|
||||
if person.source == "bangumi":
|
||||
simplified["career"] = person.career
|
||||
simplified["relation"] = person.relation
|
||||
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索人物失败: {str(e)}"
|
||||
logger.error(f"搜索人物失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
85
app/agent/tools/impl/search_person_credits.py
Normal file
85
app/agent/tools/impl/search_person_credits.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""搜索演员参演作品工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonCreditsInput(BaseModel):
|
||||
"""搜索演员参演作品工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)")
|
||||
source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
|
||||
|
||||
class SearchPersonCreditsTool(MoviePilotTool):
|
||||
name: str = "search_person_credits"
|
||||
description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in."
|
||||
args_schema: Type[BaseModel] = SearchPersonCreditsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
person_id = kwargs.get("person_id", "")
|
||||
source = kwargs.get("source", "")
|
||||
return f"正在搜索人物参演作品: {source} ID {person_id}"
|
||||
|
||||
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")
|
||||
|
||||
try:
|
||||
# 根据source选择相应的chain
|
||||
if source.lower() == "tmdb":
|
||||
tmdb_chain = TmdbChain()
|
||||
medias = await tmdb_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "douban":
|
||||
douban_chain = DoubanChain()
|
||||
medias = await douban_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "bangumi":
|
||||
bangumi_chain = BangumiChain()
|
||||
medias = await bangumi_chain.async_person_credits(person_id=person_id)
|
||||
else:
|
||||
return f"不支持的数据源: {source}。支持的数据源: tmdb, douban, bangumi"
|
||||
|
||||
if medias:
|
||||
# 限制最多30条结果
|
||||
total_count = len(medias)
|
||||
limited_medias = medias[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for media in limited_medias:
|
||||
simplified = {
|
||||
"title": media.title,
|
||||
"en_title": media.en_title,
|
||||
"year": media.year,
|
||||
"type": media.type.value if media.type else None,
|
||||
"season": media.season,
|
||||
"tmdb_id": media.tmdb_id,
|
||||
"imdb_id": media.imdb_id,
|
||||
"douban_id": media.douban_id,
|
||||
"overview": media.overview[:200] + "..." if media.overview and len(media.overview) > 200 else media.overview,
|
||||
"vote_average": media.vote_average,
|
||||
"poster_path": media.poster_path,
|
||||
"backdrop_path": media.backdrop_path,
|
||||
"detail_link": media.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到人物 ID {person_id} ({source}) 的参演作品"
|
||||
except Exception as e:
|
||||
error_message = f"搜索演员参演作品失败: {str(e)}"
|
||||
logger.error(f"搜索演员参演作品失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
127
app/agent/tools/impl/search_subscribe.py
Normal file
127
app/agent/tools/impl/search_subscribe.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""搜索订阅缺失剧集工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
name: str = "search_subscribe"
|
||||
description: str = "Search for missing episodes/resources for a specific subscription. This tool will search torrent sites for the missing episodes of the subscription and automatically download matching resources. Use this when a user wants to search for missing episodes of a specific subscription."
|
||||
args_schema: Type[BaseModel] = SearchSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
# 先验证订阅是否存在
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = subscribe_oper.get(subscribe_id)
|
||||
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取订阅信息用于返回
|
||||
subscribe_info = {
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
"lack_episode": subscribe.lack_episode,
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
}
|
||||
|
||||
# 检查订阅状态
|
||||
if subscribe.state == "S":
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
|
||||
"subscribe": subscribe_info
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 如果提供了 filter_groups 参数,先更新订阅的规则组
|
||||
if filter_groups is not None:
|
||||
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
|
||||
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
|
||||
|
||||
# 调用 SubscribeChain 的 search 方法
|
||||
# search 方法是同步的,需要在异步环境中运行
|
||||
subscribe_chain = SubscribeChain()
|
||||
|
||||
# 在线程池中执行同步的搜索操作
|
||||
# 当 sid 有值时,state 参数会被忽略,直接处理该订阅
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscribe_chain.search(
|
||||
sid=subscribe_id,
|
||||
state='R', # 默认状态,当 sid 有值时此参数会被忽略
|
||||
manual=manual
|
||||
)
|
||||
)
|
||||
|
||||
# 重新获取订阅信息以获取更新后的状态
|
||||
updated_subscribe = subscribe_oper.get(subscribe_id)
|
||||
if updated_subscribe:
|
||||
subscribe_info.update({
|
||||
"state": updated_subscribe.state,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"last_update": updated_subscribe.last_update,
|
||||
"filter_groups": updated_subscribe.filter_groups
|
||||
})
|
||||
|
||||
# 如果提供了规则组,会在返回信息中显示
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
|
||||
"subscribe": subscribe_info
|
||||
}
|
||||
|
||||
if filter_groups is not None:
|
||||
result["message"] += f"(已应用规则组: {', '.join(filter_groups)})"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
|
||||
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
@@ -79,7 +80,7 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
if season is not None and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
@@ -99,7 +100,7 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"size": StringUtils.format_size(t.torrent_info.size),
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
|
||||
182
app/agent/tools/impl/search_web.py
Normal file
182
app/agent/tools/impl/search_web.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Type, List, Dict
|
||||
|
||||
import httpx
|
||||
from ddgs import DDGS
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
# 搜索超时时间(秒)
|
||||
SEARCH_TIMEOUT = 20
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
name: str = "search_web"
|
||||
description: str = "Search the web for information when you need to find current information, facts, or references that you're uncertain about. Returns search results with titles, snippets, and URLs. Use this tool to get up-to-date information from the internet."
|
||||
args_schema: Type[BaseModel] = SearchWebInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 5)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
results = []
|
||||
|
||||
# 1. 优先使用 Tavily (如果配置了 API Key)
|
||||
if settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
|
||||
if not results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
|
||||
# 格式化并裁剪结果
|
||||
formatted_results = self._format_and_truncate_results(results, max_results)
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索网络内容失败: {str(e)}"
|
||||
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": settings.TAVILY_API_KEY,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('content', ''),
|
||||
'url': result.get('url', ''),
|
||||
'source': 'Tavily'
|
||||
})
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
"""从代理设置中提取代理URL"""
|
||||
if not proxy_setting:
|
||||
return None
|
||||
if isinstance(proxy_setting, dict):
|
||||
return proxy_setting.get('http') or proxy_setting.get('https')
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
try:
|
||||
def sync_search():
|
||||
results = []
|
||||
ddgs_kwargs = {
|
||||
'timeout': SEARCH_TIMEOUT
|
||||
}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
if proxy_url:
|
||||
ddgs_kwargs['proxy'] = proxy_url
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(
|
||||
query,
|
||||
max_results=max_results
|
||||
)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('body', ''),
|
||||
'url': result.get('href', ''),
|
||||
'source': 'DuckDuckGo'
|
||||
})
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, sync_search)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
|
||||
"""格式化并裁剪搜索结果"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
|
||||
for idx, result in enumerate(results[:max_results], 1):
|
||||
title = result.get("title", "")[:200]
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要
|
||||
max_snippet_length = 500 # 增加到500字符,提供更多上下文
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"仅显示前 {max_results} 条结果。"
|
||||
|
||||
return formatted
|
||||
@@ -20,7 +20,7 @@ class UpdateSiteInput(BaseModel):
|
||||
site_id: int = Field(..., description="The ID of the site to update")
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, higher number = higher priority)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
@@ -39,7 +39,7 @@ class UpdateSiteInput(BaseModel):
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once."
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -29,7 +29,7 @@ class UpdateSubscribeInput(BaseModel):
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
|
||||
274
app/agent/tools/manager.py
Normal file
274
app/agent/tools/manager.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""
|
||||
工具定义
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.input_schema = input_schema
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""
|
||||
加载所有MoviePilot工具
|
||||
"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {e}", exc_info=True)
|
||||
self.tools = []
|
||||
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in properties:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
# 从 anyOf 中提取实际类型
|
||||
for type_option in any_of:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
elif isinstance(value, (int, float)):
|
||||
normalized[key] = value != 0
|
||||
else:
|
||||
normalized[key] = True
|
||||
else:
|
||||
normalized[key] = value
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
try:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
formated_result = str(result)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
# 获取Pydantic模型的字段信息
|
||||
schema = args_schema.model_json_schema()
|
||||
|
||||
# 构建JSON Schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
# 转换字段类型
|
||||
field_type = field_info.get("type", "string")
|
||||
field_description = field_info.get("description", "")
|
||||
|
||||
# 处理可选字段
|
||||
if field_name not in schema.get("required", []):
|
||||
# 可选字段
|
||||
default_value = field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
# 处理枚举类型
|
||||
if "enum" in field_info:
|
||||
properties[field_name]["enum"] = field_info["enum"]
|
||||
|
||||
# 处理数组类型
|
||||
if field_type == "array" and "items" in field_info:
|
||||
properties[field_name]["items"] = field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
@@ -2,11 +2,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
@@ -28,3 +29,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
|
||||
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
|
||||
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
|
||||
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
@@ -6,12 +6,13 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,8 +67,8 @@ def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
bangumiid: Annotated[int | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
# 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
"""
|
||||
@@ -76,9 +77,13 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
@@ -88,6 +93,7 @@ def add(
|
||||
media_info=mediainfo,
|
||||
torrent_info=torrentinfo
|
||||
)
|
||||
|
||||
did = DownloadChain().download_single(context=context, username=current_user.name,
|
||||
downloader=downloader, save_path=save_path, source="Manual")
|
||||
if not did:
|
||||
|
||||
@@ -4,6 +4,7 @@ import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
@@ -11,7 +12,7 @@ from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType
|
||||
@@ -98,6 +99,8 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
state = StorageChain().delete_media_file(src_fileitem)
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=f"{src_fileitem.path} 删除失败")
|
||||
# 删除下载记录中关联的文件
|
||||
DownloadFiles.delete_by_fullpath(db, Path(src_fileitem.path).as_posix())
|
||||
# 发送事件
|
||||
eventmanager.send_event(
|
||||
EventType.DownloadFileDeleted,
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.helper.image import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
@@ -29,7 +29,14 @@ def login_access_token(
|
||||
mfa_code=otp_password)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
@@ -50,7 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
353
app/api/endpoints/mcp.py
Normal file
353
app/api/endpoints/mcp.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import moviepilot_tool_manager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
# 导入版本号
|
||||
try:
|
||||
from version import APP_VERSION
|
||||
except ImportError:
|
||||
APP_VERSION = "unknown"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 成功响应
|
||||
"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
|
||||
str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 错误响应
|
||||
"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
}
|
||||
if data is not None:
|
||||
error["error"]["data"] = data
|
||||
return error
|
||||
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析请求体失败: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
|
||||
)
|
||||
|
||||
# 验证 JSON-RPC 格式
|
||||
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
|
||||
)
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
# 如果有 id,则为请求;没有 id 则为通知
|
||||
is_notification = request_id is None
|
||||
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
elif method == "ping":
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
|
||||
|
||||
# 未知方法
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
处理初始化请求
|
||||
"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
# 客户端版本在支持列表中,使用客户端版本
|
||||
negotiated_version = protocol_version
|
||||
logger.info(f"使用客户端协议版本: {negotiated_version}")
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
for tool in tools:
|
||||
mcp_tool = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
mcp_tools.append(mcp_tool)
|
||||
|
||||
return {
|
||||
"tools": mcp_tools
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
处理工具调用请求
|
||||
"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
try:
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result_text
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"错误: {str(e)}"
|
||||
}
|
||||
],
|
||||
"isError": True
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有可用的工具列表
|
||||
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具定义
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
tools_list.append(tool_dict)
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 调用工具
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
result=result_text
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {request.tool_name} 失败: {e}", exc_info=True)
|
||||
return schemas.ToolCallResponse(
|
||||
success=False,
|
||||
error=f"调用工具失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}", summary="获取工具详情", response_model=Dict[str, Any])
|
||||
async def get_tool_info(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的详细信息
|
||||
|
||||
Returns:
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具信息失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}/schema", summary="获取工具参数Schema", response_model=Dict[str, Any])
|
||||
async def get_tool_schema(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的参数Schema(JSON Schema格式)
|
||||
|
||||
Returns:
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return tool.input_schema
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具Schema失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具Schema失败: {str(e)}")
|
||||
@@ -11,7 +11,10 @@ from app.core.context import Context
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo, MetaInfoPath
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_superuser
|
||||
from app.schemas import MediaType, MediaRecognizeConvertEventData
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import ChainEventType
|
||||
|
||||
router = APIRouter()
|
||||
@@ -85,25 +88,26 @@ async def search(title: str,
|
||||
return obj.get("source")
|
||||
return obj.source
|
||||
|
||||
result = []
|
||||
media_chain = MediaChain()
|
||||
if type == "media":
|
||||
_, medias = await media_chain.async_search(title=title)
|
||||
if medias:
|
||||
result = [media.to_dict() for media in medias]
|
||||
result = [media.to_dict() for media in medias] if medias else []
|
||||
elif type == "collection":
|
||||
result = await media_chain.async_search_collections(name=title)
|
||||
else:
|
||||
result = await media_chain.async_search_persons(name=title)
|
||||
if result:
|
||||
# 按设置的顺序对结果进行排序
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') or []
|
||||
sort_order = {}
|
||||
for index, source in enumerate(setting_order):
|
||||
sort_order[source] = index
|
||||
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return result[(page - 1) * count:page * count]
|
||||
return []
|
||||
collections = await media_chain.async_search_collections(name=title)
|
||||
result = [collection.to_dict() for collection in collections] if collections else []
|
||||
else: # person
|
||||
persons = await media_chain.async_search_persons(name=title)
|
||||
result = [person.model_dump() for person in persons] if persons else []
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# 排序和分页
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') if settings.SEARCH_SOURCE else []
|
||||
sort_order = {source: index for index, source in enumerate(setting_order)}
|
||||
|
||||
sorted_result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return sorted_result[(page - 1) * count:page * count]
|
||||
|
||||
|
||||
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
|
||||
@@ -130,6 +134,26 @@ def scrape(fileitem: schemas.FileItem,
|
||||
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
|
||||
|
||||
|
||||
@router.get("/category/config", summary="获取分类策略配置", response_model=schemas.Response)
|
||||
def get_category_config(_: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取分类策略配置
|
||||
"""
|
||||
config = MediaChain().category_config()
|
||||
return schemas.Response(success=True, data=config.model_dump())
|
||||
|
||||
|
||||
@router.post("/category/config", summary="保存分类策略配置", response_model=schemas.Response)
|
||||
def save_category_config(config: CategoryConfig, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
保存分类策略配置
|
||||
"""
|
||||
if MediaChain().save_category_config(config):
|
||||
return schemas.Response(success=True, message="保存成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message="保存失败")
|
||||
|
||||
|
||||
@router.get("/category", summary="查询自动分类配置", response_model=dict)
|
||||
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -171,7 +195,7 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
tmdbid = int(mediaid[5:])
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
if title:
|
||||
@@ -183,11 +207,11 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
else:
|
||||
sea = season or 1
|
||||
sea = season if season is not None else 1
|
||||
return [schemas.MediaSeason(
|
||||
season_number=sea,
|
||||
poster_path=mediainfo.poster_path,
|
||||
|
||||
@@ -54,7 +54,7 @@ async def exists_local(title: Optional[str] = None,
|
||||
判断本地是否存在
|
||||
"""
|
||||
meta = MetaInfo(title)
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
# 返回对象
|
||||
ret_info = {}
|
||||
@@ -82,8 +82,8 @@ def exists(media_in: schemas.MediaInfo,
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
if media_in.season:
|
||||
return {}
|
||||
if media_in.season is not None:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
}
|
||||
@@ -101,7 +101,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
mtype = MediaType(media_in.type) if media_in.type else None
|
||||
if mtype:
|
||||
meta.type = mtype
|
||||
if media_in.season:
|
||||
if media_in.season is not None:
|
||||
meta.begin_season = media_in.season
|
||||
meta.type = MediaType.TV
|
||||
if media_in.year:
|
||||
|
||||
498
app/api/endpoints/mfa.py
Normal file
498
app/api/endpoints/mfa.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db import get_async_db
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.passkey import PassKeyHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建凭证列表
|
||||
|
||||
:param passkeys: PassKey 列表
|
||||
:return: 凭证字典列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
|
||||
def _extract_and_standardize_credential_id(credential: dict) -> str:
|
||||
"""
|
||||
从凭证中提取并标准化 credential_id
|
||||
|
||||
:param credential: 凭证字典
|
||||
:return: 标准化后的 credential_id
|
||||
:raises ValueError: 如果凭证无效
|
||||
"""
|
||||
credential_id_raw = credential.get('id') or credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise ValueError("无效的凭证")
|
||||
return PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
|
||||
def _verify_passkey_and_update(
|
||||
credential: dict,
|
||||
challenge: str,
|
||||
passkey: PassKey
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
验证 PassKey 并更新使用时间和签名计数
|
||||
|
||||
:param credential: 凭证字典
|
||||
:param challenge: 挑战值
|
||||
:param passkey: PassKey 对象
|
||||
:return: (验证是否成功, 新的签名计数)
|
||||
"""
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if success:
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
return success, new_sign_count
|
||||
|
||||
|
||||
async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否有 PassKey
|
||||
|
||||
:param db: 数据库会话
|
||||
:param user_id: 用户 ID
|
||||
:return: 是否有 PassKey
|
||||
"""
|
||||
return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id))
|
||||
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey)
|
||||
"""
|
||||
user: User = await User.async_get_by_name(db, username)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
|
||||
# 检查是否启用了OTP
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = await _check_user_has_passkey(db, user.id)
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
|
||||
|
||||
# ==================== OTP 相关接口 ====================
|
||||
|
||||
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""生成 OTP 密钥及对应的 URI"""
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,默认不允许关闭 OTP,除非配置允许
|
||||
has_passkey = await _check_user_has_passkey(db, current_user.id)
|
||||
if has_passkey and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
# ==================== PassKey 相关接口 ====================
|
||||
|
||||
class PassKeyRegistrationStart(schemas.BaseModel):
|
||||
"""PassKey注册开始请求"""
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
"""PassKey注册完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
"""PassKey认证完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
|
||||
|
||||
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_start(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:默认需要先启用 OTP,除非配置允许在未启用 OTP 时注册
|
||||
if not current_user.is_otp and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
user_id=current_user.id,
|
||||
username=current_user.name,
|
||||
display_name=current_user.settings.get('nickname') if current_user.settings else None,
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey注册选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成注册选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_finish(
|
||||
passkey_req: PassKeyRegistrationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""完成注册 PassKey - 验证并保存凭证"""
|
||||
try:
|
||||
# 验证注册响应
|
||||
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge
|
||||
)
|
||||
|
||||
# 提取transports
|
||||
transports = None
|
||||
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
|
||||
transports = ','.join(passkey_req.credential['response']['transports'])
|
||||
|
||||
# 保存到数据库
|
||||
passkey = PassKey(
|
||||
user_id=current_user.id,
|
||||
credential_id=credential_id,
|
||||
public_key=public_key,
|
||||
sign_count=sign_count,
|
||||
name=passkey_req.name or "通行密钥",
|
||||
aaguid=aaguid,
|
||||
transports=transports
|
||||
)
|
||||
passkey.create()
|
||||
|
||||
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥注册成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"注册失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
|
||||
def passkey_authenticate_start(
|
||||
passkey_req: PassKeyAuthenticationStart = Body(...)
|
||||
) -> Any:
|
||||
"""开始 PassKey 认证 - 生成认证选项"""
|
||||
try:
|
||||
existing_credentials = None
|
||||
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None
|
||||
|
||||
if not user or not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
existing_credentials = _build_credential_list(existing_passkeys)
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
|
||||
def passkey_authenticate_finish(
|
||||
passkey_req: PassKeyAuthenticationFinish
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey认证失败,提供的凭证无效: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 查找PassKey并获取用户
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None
|
||||
if not passkey or not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
# 生成token
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
def passkey_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
key_list = [
|
||||
{
|
||||
'id': pk.id,
|
||||
'name': pk.name,
|
||||
'created_at': pk.created_at.isoformat() if pk.created_at else None,
|
||||
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
|
||||
'aaguid': pk.aaguid,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=key_list
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取PassKey列表失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"获取列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或无权删除"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"删除PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"删除失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
|
||||
def passkey_verify_mfa(
|
||||
passkey_req: PassKeyAuthenticationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey二次验证失败,提供的凭证无效: {e}")
|
||||
return schemas.Response(success=False, message="验证失败")
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="二次验证成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="验证失败"
|
||||
)
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
|
||||
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
@@ -167,3 +175,87 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
|
||||
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
|
||||
async def recommend_search_results(
|
||||
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
|
||||
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
|
||||
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
AI推荐资源 - 轮询接口
|
||||
前端轮询此接口,发送筛选后的索引(如果有筛选)
|
||||
后端根据请求变化自动取消旧任务并启动新任务
|
||||
|
||||
参数:
|
||||
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
|
||||
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
|
||||
- force: 强制重新推荐(清除旧结果并重新启动)
|
||||
|
||||
返回数据结构:
|
||||
{
|
||||
"success": bool,
|
||||
"message": string, // 错误信息(仅在错误时存在)
|
||||
"data": {
|
||||
"status": string, // 状态: disabled | idle | running | completed | error
|
||||
"results": array // 推荐结果(仅status=completed时存在)
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 从缓存获取上次搜索结果
|
||||
results = await SearchChain().async_last_search_results() or []
|
||||
if not results:
|
||||
return schemas.Response(success=False, message="没有可用的搜索结果", data={
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
logger.info("收到新推荐请求,清除旧结果并启动新任务")
|
||||
recommend_chain.cancel_ai_recommend()
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 直接返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=current_status)
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
# 如果是空闲状态,启动新任务
|
||||
if status_data["status"] == "idle":
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 立即返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if status_data.get("status") == "error":
|
||||
error_msg = status_data.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=status_data)
|
||||
|
||||
# 返回当前状态
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
@@ -92,10 +92,14 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
site_in.domain = StringUtils.get_url_domain(site_in.url)
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
"site_id": site_in.id,
|
||||
"domain": site_in.domain,
|
||||
"name": site_in.name,
|
||||
"site_url": site_in.url
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -31,6 +31,17 @@ def qrcode(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/auth_url/{name}", summary="获取 OAuth2 授权 URL", response_model=schemas.Response)
|
||||
def auth_url(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取 OAuth2 授权 URL
|
||||
"""
|
||||
auth_data, errmsg = StorageChain().generate_auth_url(name)
|
||||
if auth_data:
|
||||
return schemas.Response(success=True, data=auth_data)
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/check/{name}", summary="二维码登录确认", response_model=schemas.Response)
|
||||
def check(name: str, ck: Optional[str] = None, t: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@@ -83,7 +94,7 @@ def list_files(fileitem: schemas.FileItem,
|
||||
if sort == "name":
|
||||
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
|
||||
else:
|
||||
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
|
||||
file_list.sort(key=lambda x: x.modify_time or -math.inf, reverse=True)
|
||||
return file_list
|
||||
|
||||
|
||||
@@ -167,7 +178,7 @@ def rename(fileitem: schemas.FileItem,
|
||||
# 重命名目录内文件
|
||||
if recursive:
|
||||
transferchain = TransferChain()
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
# 递归修改目录内文件(智能识别命名)
|
||||
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
|
||||
if sub_files:
|
||||
|
||||
@@ -199,7 +199,7 @@ async def subscribe_mediaid(
|
||||
# 使用名称检查订阅
|
||||
if title_check and title:
|
||||
meta = MetaInfo(title)
|
||||
if season:
|
||||
if season is not None:
|
||||
meta.begin_season = season
|
||||
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Annotated
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
@@ -19,7 +16,6 @@ from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.cache import AsyncFileCache
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -29,12 +25,14 @@ from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
@@ -50,7 +48,7 @@ router = APIRouter()
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: bool = False,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
@@ -69,77 +67,24 @@ async def fetch_image(
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
if not cache_path.suffix:
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 缓存对像,缓存过期时间为全局图片缓存天数
|
||||
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
if use_cache:
|
||||
content = await cache_backend.get(cache_path.as_posix(), region="images")
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies,
|
||||
referer=referer,
|
||||
content = await ImageHelper().async_fetch_image(
|
||||
url=url,
|
||||
proxy=proxy,
|
||||
use_cache=use_cache,
|
||||
cookies=cookies,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*",
|
||||
).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format for URL {url}: {e}")
|
||||
return None
|
||||
|
||||
# 获取请求响应头
|
||||
response_headers = response.headers
|
||||
cache_control_header = response_headers.get("Cache-Control", "")
|
||||
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
|
||||
|
||||
# 保存缓存
|
||||
if use_cache:
|
||||
await cache_backend.set(cache_path.as_posix(), content, region="images")
|
||||
logger.debug(f"Image cached at {cache_path.as_posix()}")
|
||||
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
if if_none_match == etag:
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(status_code=304, headers=headers)
|
||||
|
||||
# 响应
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
@@ -177,8 +122,7 @@ async def cache_img(
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
proxy = "doubanio.com" not in url
|
||||
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
|
||||
|
||||
@@ -186,22 +130,53 @@ async def cache_img(
|
||||
def get_global_setting(token: str):
|
||||
"""
|
||||
查询非敏感系统设置(默认鉴权)
|
||||
仅包含登录前UI初始化必需的字段
|
||||
"""
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
# 白名单模式,仅包含登录前UI初始化必需的字段
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update({
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
|
||||
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
包含业务功能相关的配置和用户权限信息
|
||||
"""
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP"
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
@@ -247,13 +222,11 @@ async def set_env_setting(env: dict,
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
for key in success_updates.keys():
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=getattr(settings, key, None),
|
||||
change_type="update"
|
||||
))
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
@@ -338,6 +311,18 @@ async def set_setting(
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
@@ -630,7 +615,10 @@ def run_scheduler(jobid: str,
|
||||
"""
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -643,5 +631,8 @@ def run_scheduler2(jobid: str,
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
|
||||
|
||||
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
|
||||
def get_config(key: str,
|
||||
current_user: User = Depends(get_current_active_user)):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -128,9 +128,12 @@ async def get_cookie(
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
async def post_cookie(
|
||||
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
|
||||
request: schemas.CookiePassword):
|
||||
request: Optional[schemas.CookiePassword] = Body(None)):
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
if request is not None:
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -4,6 +4,7 @@ import pickle
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
@@ -25,6 +26,7 @@ from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
|
||||
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel
|
||||
from app.utils.object import ObjectUtils
|
||||
|
||||
@@ -250,6 +252,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
logger.error(traceback.format_exc())
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
@@ -291,6 +294,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
logger.error(traceback.format_exc())
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
@@ -849,6 +853,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
@@ -932,6 +938,8 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
@@ -1055,6 +1063,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("media_category")
|
||||
|
||||
def category_config(self) -> CategoryConfig:
|
||||
"""
|
||||
获取分类策略配置
|
||||
"""
|
||||
return self.run_module("load_category_config")
|
||||
|
||||
def save_category_config(self, config: CategoryConfig) -> bool:
|
||||
"""
|
||||
保存分类策略配置
|
||||
"""
|
||||
return self.run_module("save_category_config", config=config)
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> None:
|
||||
"""
|
||||
注册菜单命令
|
||||
|
||||
318
app/chain/ai_recommend.py
Normal file
318
app/chain/ai_recommend.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
"""
|
||||
|
||||
# 缓存文件名
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
# AI推荐状态
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None # 当前请求的哈希值
|
||||
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表)
|
||||
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
:return: 状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
# 尝试从数据库加载缓存
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
# 只要有结果,始终返回completed状态和数据
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
# 计算当前请求的hash
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 检查请求是否变化
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
# 如果请求变化了(筛选条件改变),返回idle状态
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
# 请求未变化,返回当前实际状态
|
||||
return self._build_status()
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
|
||||
"""
|
||||
AI推荐
|
||||
:param items: 候选资源列表(JSON字符串格式)
|
||||
:param preference: 用户偏好(可选)
|
||||
:return: AI返回的推荐结果
|
||||
"""
|
||||
# 设置运行状态
|
||||
self._ai_recommend_running = True
|
||||
try:
|
||||
# 导入LLMHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
# 获取LLM实例
|
||||
llm = LLMHelper.get_llm()
|
||||
|
||||
# 构建提示词
|
||||
user_preference = (
|
||||
preference
|
||||
or settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
# 添加指令
|
||||
instruction = """
|
||||
Task: Select the best matching items from the list based on user preferences.
|
||||
|
||||
Each item contains:
|
||||
- index: Item number
|
||||
- title: Full torrent title
|
||||
- size: File size
|
||||
- seeders: Number of seeders
|
||||
|
||||
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
|
||||
"""
|
||||
message = (
|
||||
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
|
||||
+ "\n".join(items)
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = await llm.ainvoke(message)
|
||||
return response.content
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"AI推荐配置错误: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# 清除运行状态
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
# 防护检查:确保AI推荐功能已启用
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
# 计算新请求的哈希值
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 如果请求变化了,取消旧任务
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
|
||||
# 更新请求哈希值
|
||||
self._current_request_hash = new_request_hash
|
||||
|
||||
# 重置状态
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
# 启动新任务
|
||||
async def run_recommend():
|
||||
# 获取当前任务对象,用于在finally中比对
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
# 准备数据
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
# 调用AI推荐
|
||||
ai_response = await self.async_ai_recommend(items)
|
||||
|
||||
# 解析AI返回的索引
|
||||
try:
|
||||
# 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组)
|
||||
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(ai_response)
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
# 映射回原始索引
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= filtered_indices[valid_indices[i]] < len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
# 只返回索引列表,不返回完整数据
|
||||
self._ai_recommend_result = original_indices
|
||||
|
||||
# 保存到数据库
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
|
||||
# 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
# 创建并启动任务
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
@@ -19,7 +19,7 @@ from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
from app.schemas import ExistMediaInfo, FileURI, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
ResourceDownloadEventData
|
||||
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
|
||||
ChainEventType
|
||||
@@ -162,7 +162,7 @@ class DownloadChain(ChainBase):
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、Subscribe、Manual等)
|
||||
:param downloader: 下载器
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
@@ -232,13 +232,14 @@ class DownloadChain(ChainBase):
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
|
||||
storage = 'local'
|
||||
# 下载目录
|
||||
if save_path:
|
||||
# 下载目录使用自定义的
|
||||
download_dir = Path(save_path)
|
||||
else:
|
||||
# 根据媒体信息查询下载目录配置
|
||||
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
|
||||
dir_info = DirectoryHelper().get_dir(_media, include_unsorted=True)
|
||||
storage = dir_info.storage if dir_info else storage
|
||||
# 拼装子目录
|
||||
if dir_info:
|
||||
# 一级目录
|
||||
@@ -259,6 +260,8 @@ class DownloadChain(ChainBase):
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
# 添加下载
|
||||
result: Optional[tuple] = self.download(content=torrent_content,
|
||||
@@ -324,9 +327,10 @@ class DownloadChain(ChainBase):
|
||||
if not file_meta.begin_episode \
|
||||
or file_meta.begin_episode not in episodes:
|
||||
continue
|
||||
# 只处理视频格式
|
||||
# 只处理音视频、字幕格式
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if not Path(file).suffix \
|
||||
or Path(file).suffix.lower() not in settings.RMT_MEDIAEXT:
|
||||
or Path(file).suffix.lower() not in media_exts:
|
||||
continue
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
@@ -400,7 +404,7 @@ class DownloadChain(ChainBase):
|
||||
根据缺失数据,自动种子列表中组合择优下载
|
||||
:param contexts: 资源上下文列表
|
||||
:param no_exists: 缺失的剧集信息
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、订阅、手工下载等)
|
||||
:param userid: 用户ID
|
||||
|
||||
@@ -150,7 +150,7 @@ class MediaChain(ChainBase):
|
||||
org_meta.year = year
|
||||
org_meta.begin_season = season_number
|
||||
org_meta.begin_episode = episode_number
|
||||
if org_meta.begin_season or org_meta.begin_episode:
|
||||
if org_meta.begin_season is not None or org_meta.begin_episode is not None:
|
||||
org_meta.type = MediaType.TV
|
||||
# 重新识别
|
||||
return self.recognize_media(meta=org_meta)
|
||||
@@ -315,21 +315,6 @@ class MediaChain(ChainBase):
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in StorageChain().list_files(fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
@eventmanager.register(EventType.MetadataScrape)
|
||||
def scrape_metadata_event(self, event: Event):
|
||||
"""
|
||||
@@ -370,7 +355,7 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
if file_list:
|
||||
# 如果是BDMV原盘目录,只对根目录进行刮削,不处理子目录
|
||||
if self.is_bluray_folder(fileitem):
|
||||
if storagechain.is_bluray_folder(fileitem):
|
||||
logger.info(f"检测到BDMV原盘目录,只对根目录进行刮削:{fileitem.path}")
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
@@ -563,10 +548,23 @@ class MediaChain(ChainBase):
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 电影目录
|
||||
if recursive:
|
||||
# 处理文件
|
||||
if self.is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
is_bluray_folder = storagechain.contains_bluray_subdirectories(files)
|
||||
if recursive and not is_bluray_folder:
|
||||
# 处理非原盘目录内的文件
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
if is_bluray_folder:
|
||||
# 检查电影NFO开关
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
|
||||
@@ -581,20 +579,6 @@ class MediaChain(ChainBase):
|
||||
logger.info(f"已存在nfo文件:{nfo_path}")
|
||||
else:
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
# 图片
|
||||
image_dict = self.metadata_img(mediainfo=mediainfo)
|
||||
if image_dict:
|
||||
@@ -618,7 +602,7 @@ class MediaChain(ChainBase):
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 流式下载图片并直接保存
|
||||
@@ -681,7 +665,11 @@ class MediaChain(ChainBase):
|
||||
if recursive:
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir" and not file.name.lower().startswith("season"):
|
||||
if (
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and not file.name.lower().startswith("season")
|
||||
):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
@@ -691,11 +679,19 @@ class MediaChain(ChainBase):
|
||||
overwrite=overwrite)
|
||||
# 生成目录的nfo和图片
|
||||
if init_folder:
|
||||
# TODO 目前的刮削是假定电视剧目录结构符合:/剧集根目录/季目录/剧集文件
|
||||
# 其中季目录应符合`Season 数字`等明确的季命名,不能用季标题
|
||||
# 例如:/Torchwood (2006)/Miracle Day/Torchwood (2006) S04E01.mkv
|
||||
# 当刮削到`Miracle Day`目录时,会误判其为剧集根目录
|
||||
# 识别文件夹名称
|
||||
season_meta = MetaInfo(filepath.name)
|
||||
# 当前文件夹为Specials或者SPs时,设置为S0
|
||||
if filepath.name in settings.RENAME_FORMAT_S0_NAMES:
|
||||
season_meta.begin_season = 0
|
||||
elif season_meta.name and season_meta.begin_season is not None:
|
||||
# 当前目录含有非季目录的名称,但却有季信息(通常是被辅助识别词指定了)
|
||||
# 这种情况应该是剧集根目录,不能按季目录刮削,否则会导致`season_poster`的路径错误 详见issue#5373
|
||||
season_meta.begin_season = None
|
||||
if season_meta.begin_season is not None:
|
||||
# 检查季NFO开关
|
||||
if scraping_switchs.get('season_nfo', True):
|
||||
@@ -765,7 +761,8 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
logger.info(f"季图片刮削已关闭,跳过:{image_name}")
|
||||
# 判断当前目录是不是剧集根目录
|
||||
if not season_meta.season:
|
||||
elif season_meta.name:
|
||||
# 不含季信息(包括特别季)但含有名称的,可以认为是剧集根目录
|
||||
# 检查电视剧NFO开关
|
||||
if scraping_switchs.get('tv_nfo', True):
|
||||
# 是否已存在
|
||||
@@ -961,10 +958,10 @@ class MediaChain(ChainBase):
|
||||
year = None
|
||||
if tmdbinfo.get('release_date'):
|
||||
year = tmdbinfo['release_date'][:4]
|
||||
elif tmdbinfo.get('seasons') and season:
|
||||
elif tmdbinfo.get('seasons') and season is not None:
|
||||
for seainfo in tmdbinfo['seasons']:
|
||||
season_number = seainfo.get("season_number")
|
||||
if not season_number:
|
||||
if season_number is None:
|
||||
continue
|
||||
air_date = seainfo.get("air_date")
|
||||
if air_date and season_number == season:
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import settings
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
@@ -40,7 +40,7 @@ class MessageChain(ChainBase):
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 15
|
||||
_session_timeout_minutes: int = 30
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
@@ -164,19 +164,15 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
# 处理消息
|
||||
if text.startswith('CALLBACK:'):
|
||||
# 处理按钮回调(适配支持回调的渠道)
|
||||
# 处理按钮回调(适配支持回调的渠),优先级最高
|
||||
if ChannelCapabilityManager.supports_callbacks(channel):
|
||||
self._handle_callback(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username,
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
elif text.startswith('/') and not text.lower().startswith('/ai'):
|
||||
# 执行特定命令命令(但不是/ai)
|
||||
self.eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
@@ -186,265 +182,231 @@ class MessageChain(ChainBase):
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
elif text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
elif text.lower().startswith('/ai'):
|
||||
# 用户指定AI智能体消息响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
# 普通消息,全局智能体响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
else:
|
||||
# 非智能体普通消息响应
|
||||
if text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
# 选择项目
|
||||
if not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(
|
||||
**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if _current_page == 0:
|
||||
# 第一页
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
# 第一页
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
@@ -452,93 +414,145 @@ class MessageChain(ChainBase):
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
|
||||
userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
finally:
|
||||
user_cache.clear()
|
||||
del user_cache
|
||||
@@ -828,42 +842,41 @@ class MessageChain(ChainBase):
|
||||
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_session_id(userid: Union[str, int]) -> str:
|
||||
def _get_or_create_session_id(self, userid: Union[str, int]) -> str:
|
||||
"""
|
||||
获取或创建会话ID
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
|
||||
if userid in self._user_sessions:
|
||||
session_id, last_time = self._user_sessions[userid]
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
|
||||
# 如果时间差小于等于xx分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=self._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
self._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
self._user_sessions[userid] = (new_session_id, current_time)
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
@staticmethod
|
||||
def clear_user_session(userid: Union[str, int]) -> bool:
|
||||
def clear_user_session(self, userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
返回是否成功清除
|
||||
"""
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
return True
|
||||
return False
|
||||
@@ -874,31 +887,23 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
# 获取并清除会话信息
|
||||
session_id = None
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
asyncio.run(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
)
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -914,7 +919,7 @@ class MessageChain(ChainBase):
|
||||
))
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
@@ -930,19 +935,11 @@ class MessageChain(ChainBase):
|
||||
))
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助未配置,请在系统设置中配置"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if text.lower().startswith("/ai"):
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
@@ -955,34 +952,20 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 生成或复用会话ID
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached, FileCache
|
||||
from app.core.cache import cached, fresh
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -31,9 +27,11 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
# 推荐缓存区域
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
def refresh_recommend(self):
|
||||
def refresh_recommend(self, manual: bool = False):
|
||||
"""
|
||||
刷新推荐
|
||||
|
||||
:param manual: 手动触发
|
||||
"""
|
||||
logger.debug("Starting to refresh Recommend data.")
|
||||
|
||||
@@ -66,7 +64,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
if method in methods_finished:
|
||||
continue
|
||||
logger.debug(f"Fetch {method.__name__} data for page {page}.")
|
||||
data = method(page=page)
|
||||
# 手动触发的刷新,总是需要获取最新数据
|
||||
with fresh(manual):
|
||||
data = method(page=page)
|
||||
if not data:
|
||||
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
|
||||
methods_finished.add(method)
|
||||
@@ -94,7 +94,6 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
poster_path = data.get("poster_path")
|
||||
if poster_path:
|
||||
poster_url = poster_path.replace("original", "w500")
|
||||
logger.debug(f"Caching poster image: {poster_url}")
|
||||
self.__fetch_and_save_image(poster_url)
|
||||
|
||||
@staticmethod
|
||||
@@ -103,40 +102,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
# 生成缓存路径
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path("images") / sanitized_path
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
|
||||
cache_backend = FileCache(base=settings.CACHE_PATH,
|
||||
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
|
||||
|
||||
# 本地存在缓存图片,则直接跳过
|
||||
if cache_backend.get(cache_path.as_posix(), region="images"):
|
||||
logger.debug(f"Cache hit: Image already exists at {cache_path}")
|
||||
return
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if not referer else None
|
||||
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
|
||||
if not response:
|
||||
logger.debug(f"Empty response for URL: {url}")
|
||||
return
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
Image.open(io.BytesIO(response.content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
return
|
||||
|
||||
# 保存缓存
|
||||
cache_backend.set(cache_path.as_posix(), response.content, region="images")
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
ImageHelper().fetch_image(url=url)
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
|
||||
@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -48,7 +49,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
@@ -116,7 +129,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -168,7 +181,7 @@ class SearchChain(ChainBase):
|
||||
# 过滤剧集
|
||||
season_episodes = {sea: info.episodes
|
||||
for sea, info in no_exists[mediakey].items()}
|
||||
elif mediainfo.season:
|
||||
elif mediainfo.season is not None:
|
||||
# 豆瓣只搜索当前季
|
||||
season_episodes = {mediainfo.season: []}
|
||||
else:
|
||||
|
||||
@@ -44,6 +44,7 @@ class SiteChain(ChainBase):
|
||||
"star-space.net": self.__indexphp_test,
|
||||
"yemapt.org": self.__yema_test,
|
||||
"hddolby.com": self.__hddolby_test,
|
||||
"rousi.pro": self.__rousi_test,
|
||||
}
|
||||
|
||||
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
|
||||
@@ -249,6 +250,32 @@ class SiteChain(ChainBase):
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __rousi_test(site: Site) -> Tuple[bool, str]:
|
||||
"""
|
||||
判断站点是否已经登陆:rousi
|
||||
"""
|
||||
url = f"https://{StringUtils.get_url_domain(site.url)}/api/v1/profile"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {site.apikey}",
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if site.proxy else None,
|
||||
timeout=site.timeout or 15
|
||||
).get_res(url=url)
|
||||
if res is None:
|
||||
return False, "无法打开网站!"
|
||||
if res.status_code == 200:
|
||||
user_info = res.json()
|
||||
if user_info and user_info.get("code") == 0:
|
||||
return True, "连接成功"
|
||||
return False, "APIKEY已过期"
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __parse_favicon(url: str, cookie: str, ua: str) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
@@ -462,20 +489,18 @@ class SiteChain(ChainBase):
|
||||
logger.warn(f"站点 {domain} 索引器不存在!")
|
||||
return
|
||||
# 查询站点图标
|
||||
site_icon = siteoper.get_icon_by_domain(domain)
|
||||
if not site_icon or not site_icon.base64:
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
|
||||
@eventmanager.register(EventType.SiteUpdated)
|
||||
def clear_site_data(self, event: Event):
|
||||
|
||||
@@ -31,6 +31,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("generate_qrcode", storage=storage)
|
||||
|
||||
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
return self.run_module("generate_auth_url", storage=storage)
|
||||
|
||||
def check_login(self, storage: str, **kwargs) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
登录确认
|
||||
@@ -133,30 +139,41 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def is_bluray_folder(self, fileitem: Optional[schemas.FileItem]) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "BDMV"):
|
||||
return True
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "CERTIFICATE"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def contains_bluray_subdirectories(fileitems: Optional[List[schemas.FileItem]]) -> bool:
|
||||
"""
|
||||
判断是否包含蓝光必备的文件夹
|
||||
"""
|
||||
required_files = {"BDMV", "CERTIFICATE"}
|
||||
return any(
|
||||
item.type == "dir" and item.name in required_files
|
||||
for item in fileitems or []
|
||||
)
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
|
||||
def __is_bluray_dir(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
_dir_files = self.list_files(fileitem=_fileitem, recursion=False)
|
||||
if _dir_files:
|
||||
for _f in _dir_files:
|
||||
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 根目录或一级目录不允许删除")
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
# 本身是目录
|
||||
if __is_bluray_dir(fileitem):
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.warn(f"正在删除蓝光原盘目录:【{fileitem.storage}】{fileitem.path}")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
|
||||
@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
@@ -144,7 +144,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -169,12 +169,12 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -188,7 +188,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -292,7 +292,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
async def async_add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
@@ -321,7 +321,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -351,7 +351,7 @@ class SubscribeChain(ChainBase):
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -365,7 +365,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -469,7 +469,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
@staticmethod
|
||||
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
|
||||
@@ -530,7 +530,7 @@ class SubscribeChain(ChainBase):
|
||||
# 生成元数据
|
||||
meta = MetaInfo(subscribe.name)
|
||||
meta.year = subscribe.year
|
||||
meta.begin_season = subscribe.season or None
|
||||
meta.begin_season = subscribe.season if subscribe.season is not None else None
|
||||
try:
|
||||
meta.type = MediaType(subscribe.type)
|
||||
except ValueError:
|
||||
@@ -949,7 +949,7 @@ class SubscribeChain(ChainBase):
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -1119,6 +1119,19 @@ class SubscribeChain(ChainBase):
|
||||
})
|
||||
logger.info(f'{subscribe.name} 订阅元数据更新完成')
|
||||
|
||||
def get_subscribe_by_source(self, source: str) -> Optional[Subscribe]:
|
||||
"""
|
||||
从来源获取订阅
|
||||
"""
|
||||
source_keyword = self.parse_subscribe_source_keyword(source)
|
||||
if not source_keyword:
|
||||
return None
|
||||
# 只保留需要的字段动态获取订阅
|
||||
valid_fields = {k: v for k, v in source_keyword.items()
|
||||
if k in ["type", "season", "tmdbid", "doubanid", "bangumiid"]}
|
||||
# 暂时不考虑订阅历史, 若有必要再添加
|
||||
return SubscribeOper().get_by(**valid_fields)
|
||||
|
||||
@staticmethod
|
||||
def follow():
|
||||
"""
|
||||
@@ -1635,7 +1648,7 @@ class SubscribeChain(ChainBase):
|
||||
info = schemas.SubscribeEpisodeInfo()
|
||||
info.title = episode.name
|
||||
info.description = episode.overview
|
||||
info.backdrop = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/w500${episode.still_path}"
|
||||
info.backdrop = settings.TMDB_IMAGE_URL(episode.still_path, "w500")
|
||||
episodes[episode.episode_number] = info
|
||||
elif subscribe.type == MediaType.TV.value:
|
||||
# 根据开始结束集计算集信息
|
||||
@@ -1655,7 +1668,7 @@ class SubscribeChain(ChainBase):
|
||||
if download_his:
|
||||
for his in download_his:
|
||||
# 查询下载文件
|
||||
files = downloadhis.get_files_by_hash(his.download_hash)
|
||||
files = downloadhis.get_files_by_hash(his.download_hash, state=1)
|
||||
if files:
|
||||
for file in files:
|
||||
# 识别文件名
|
||||
@@ -1828,8 +1841,9 @@ class SubscribeChain(ChainBase):
|
||||
def get_subscribe_source_keyword(subscribe: Subscribe) -> str:
|
||||
"""
|
||||
构造用于订阅来源的关键字字符串
|
||||
|
||||
:param subscribe: Subscribe 对象
|
||||
:return: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return str: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
"""
|
||||
source_keyword = {
|
||||
'id': subscribe.id,
|
||||
@@ -1844,3 +1858,24 @@ class SubscribeChain(ChainBase):
|
||||
'bangumiid': subscribe.bangumiid
|
||||
}
|
||||
return f"Subscribe|{json.dumps(source_keyword, ensure_ascii=False)}"
|
||||
|
||||
@staticmethod
|
||||
def parse_subscribe_source_keyword(source_keyword_str: str) -> Optional[dict]:
|
||||
"""
|
||||
解析订阅来源关键字字符串
|
||||
|
||||
:param source_keyword_str: 订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return Dict: 如果解析失败则返回None
|
||||
"""
|
||||
if not source_keyword_str or not source_keyword_str.startswith("Subscribe|"):
|
||||
return None
|
||||
|
||||
try:
|
||||
# 分割字符串获取JSON部分
|
||||
json_part = source_keyword_str.split("|", 1)[1]
|
||||
# 解析JSON字符串
|
||||
source_keyword = json.loads(json_part)
|
||||
return source_keyword
|
||||
except (IndexError, json.JSONDecodeError, TypeError) as e:
|
||||
logger.error(f"解析订阅来源关键字失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -265,6 +265,9 @@ class TorrentsChain(ChainBase):
|
||||
for torrent in torrents:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
if not torrent.enclosure:
|
||||
logger.warn(f"缺少种子链接,忽略处理: {torrent.title}")
|
||||
continue
|
||||
logger.info(f'处理资源:{torrent.title} ...')
|
||||
# 识别
|
||||
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,7 +52,10 @@ class UserChain(ChainBase):
|
||||
success, user_or_message = self.password_authenticate(credentials=credentials)
|
||||
if success:
|
||||
# 如果用户启用了二次验证码,则进一步验证
|
||||
if not self._verify_mfa(user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
logger.info(f"用户 {username} 通过密码认证成功")
|
||||
return True, user_or_message
|
||||
@@ -63,7 +66,10 @@ class UserChain(ChainBase):
|
||||
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
|
||||
if aux_success:
|
||||
# 辅助认证成功后再验证二次验证码
|
||||
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
return True, aux_user_or_message
|
||||
else:
|
||||
@@ -159,22 +165,46 @@ class UserChain(ChainBase):
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
|
||||
@staticmethod
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]:
|
||||
"""
|
||||
验证 MFA(二次验证码)
|
||||
检查用户是否启用了 OTP 或 PassKey,如果启用了任何一种,都需要提供验证
|
||||
|
||||
:param user: 用户对象
|
||||
:param mfa_code: 二次验证码
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:param mfa_code: 二次验证码(如果提供了则验证OTP)
|
||||
:return:
|
||||
- 如果验证成功返回 True
|
||||
- 如果需要MFA但未提供,返回 "MFA_REQUIRED"
|
||||
- 如果MFA验证失败,返回 False
|
||||
"""
|
||||
if not user.is_otp:
|
||||
# 检查用户是否有PassKey
|
||||
from app.db.models.passkey import PassKey
|
||||
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
|
||||
|
||||
# 如果用户既没有启用OTP也没有PassKey,直接通过
|
||||
if not user.is_otp and not has_passkey:
|
||||
return True
|
||||
|
||||
# 如果用户启用了OTP或PassKey,但没有提供验证码,需要进行二次验证
|
||||
if not mfa_code:
|
||||
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
|
||||
return False
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
# 如果提供了验证码,且用户启用了 OTP,则验证 OTP
|
||||
if user.is_otp:
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
# OTP 验证成功
|
||||
return True
|
||||
|
||||
# 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证
|
||||
if has_passkey:
|
||||
logger.info(
|
||||
f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
|
||||
)
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
return True
|
||||
|
||||
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
|
||||
|
||||
@@ -27,8 +27,6 @@ DEFAULT_CACHE_SIZE = 1024
|
||||
# 默认缓存有效期
|
||||
DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
# 上下文变量来控制缓存行为
|
||||
_fresh = contextvars.ContextVar('fresh', default=False)
|
||||
|
||||
@@ -297,14 +295,14 @@ class AsyncCacheBackend(CacheBackend):
|
||||
"""
|
||||
获取所有缓存键,类似 dict.keys()(异步)
|
||||
"""
|
||||
async for key, _ in await self.items(region=region):
|
||||
async for key, _ in self.items(region=region):
|
||||
yield key
|
||||
|
||||
async def values(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
获取所有缓存值,类似 dict.values()(异步)
|
||||
"""
|
||||
async for _, value in await self.items(region=region):
|
||||
async for _, value in self.items(region=region):
|
||||
yield value
|
||||
|
||||
async def update(self, other: Dict[str, Any], region: Optional[str] = DEFAULT_CACHE_REGION,
|
||||
@@ -332,7 +330,7 @@ class AsyncCacheBackend(CacheBackend):
|
||||
弹出最后一个缓存项,类似 dict.popitem()(异步)
|
||||
"""
|
||||
items = []
|
||||
async for item in await self.items(region=region):
|
||||
async for item in self.items(region=region):
|
||||
items.append(item)
|
||||
if not items:
|
||||
raise KeyError("popitem(): cache is empty")
|
||||
@@ -364,6 +362,11 @@ class MemoryBackend(CacheBackend):
|
||||
基于 `cachetools.TTLCache` 实现的缓存后端
|
||||
"""
|
||||
|
||||
# 类变量 _region_caches 的互斥锁
|
||||
_lock = threading.Lock()
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
_region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __init__(self, cache_type: Literal['ttl', 'lru'] = 'ttl',
|
||||
maxsize: Optional[int] = None, ttl: Optional[int] = None):
|
||||
"""
|
||||
@@ -376,8 +379,6 @@ class MemoryBackend(CacheBackend):
|
||||
self.cache_type = cache_type
|
||||
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
|
||||
self.ttl = ttl or DEFAULT_CACHE_TTL
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
|
||||
"""
|
||||
@@ -400,7 +401,7 @@ class MemoryBackend(CacheBackend):
|
||||
maxsize = kwargs.get("maxsize", self.maxsize)
|
||||
region = self.get_region(region)
|
||||
# 设置缓存值
|
||||
with lock:
|
||||
with self._lock:
|
||||
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
|
||||
region_cache = self._region_caches.setdefault(
|
||||
region,
|
||||
@@ -445,7 +446,7 @@ class MemoryBackend(CacheBackend):
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
with lock:
|
||||
with self._lock:
|
||||
del region_cache[key]
|
||||
|
||||
def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
|
||||
@@ -458,13 +459,13 @@ class MemoryBackend(CacheBackend):
|
||||
# 清理指定缓存区
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache:
|
||||
with lock:
|
||||
with self._lock:
|
||||
region_cache.clear()
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
with lock:
|
||||
with self._lock:
|
||||
region_cache.clear()
|
||||
logger.info("Cleared all cache")
|
||||
|
||||
@@ -480,7 +481,7 @@ class MemoryBackend(CacheBackend):
|
||||
yield from ()
|
||||
return
|
||||
# 使用锁保护迭代过程,避免在迭代时缓存被修改
|
||||
with lock:
|
||||
with self._lock:
|
||||
# 创建快照避免并发修改问题
|
||||
items_snapshot = list(region_cache.items())
|
||||
for item in items_snapshot:
|
||||
@@ -507,18 +508,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param maxsize: 缓存的最大条目数
|
||||
:param ttl: 默认缓存存活时间,单位秒
|
||||
"""
|
||||
self.cache_type = cache_type
|
||||
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
|
||||
self.ttl = ttl or DEFAULT_CACHE_TTL
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
|
||||
"""
|
||||
获取指定区域的缓存实例,如果不存在则返回 None
|
||||
"""
|
||||
region = self.get_region(region)
|
||||
return self._region_caches.get(region)
|
||||
self._backend = MemoryBackend(cache_type=cache_type, maxsize=maxsize, ttl=ttl)
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None,
|
||||
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
|
||||
@@ -530,18 +520,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param ttl: 缓存的存活时间,不传入为永久缓存,单位秒
|
||||
:param region: 缓存的区
|
||||
"""
|
||||
ttl = ttl or self.ttl
|
||||
maxsize = kwargs.get("maxsize", self.maxsize)
|
||||
region = self.get_region(region)
|
||||
# 设置缓存值
|
||||
with lock:
|
||||
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
|
||||
region_cache = self._region_caches.setdefault(
|
||||
region,
|
||||
MemoryTTLCache(maxsize=maxsize, ttl=ttl) if self.cache_type == 'ttl'
|
||||
else MemoryLRUCache(maxsize=maxsize)
|
||||
)
|
||||
region_cache[key] = value
|
||||
return self._backend.set(key=key, value=value, ttl=ttl, region=region, **kwargs)
|
||||
|
||||
async def exists(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> bool:
|
||||
"""
|
||||
@@ -551,10 +530,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 存在返回 True,否则返回 False
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return False
|
||||
return key in region_cache
|
||||
return self._backend.exists(key=key, region=region)
|
||||
|
||||
async def get(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> Any:
|
||||
"""
|
||||
@@ -564,10 +540,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 返回缓存的值,如果缓存不存在返回 None
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return None
|
||||
return region_cache.get(key)
|
||||
return self._backend.get(key=key, region=region)
|
||||
|
||||
async def delete(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION):
|
||||
"""
|
||||
@@ -576,11 +549,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param key: 缓存的键
|
||||
:param region: 缓存的区
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
with lock:
|
||||
del region_cache[key]
|
||||
return self._backend.delete(key=key, region=region)
|
||||
|
||||
async def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
|
||||
"""
|
||||
@@ -588,19 +557,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
|
||||
:param region: 缓存的区,为None时清空所有区缓存
|
||||
"""
|
||||
if region:
|
||||
# 清理指定缓存区
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache:
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.info("All cache cleared!")
|
||||
return self._backend.clear(region=region)
|
||||
|
||||
async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]:
|
||||
"""
|
||||
@@ -609,14 +566,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 返回一个字典,包含所有缓存键值对
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
# 使用锁保护迭代过程,避免在迭代时缓存被修改
|
||||
with lock:
|
||||
# 创建快照避免并发修改问题
|
||||
items_snapshot = list(region_cache.items())
|
||||
for item in items_snapshot:
|
||||
for item in self._backend.items(region):
|
||||
yield item
|
||||
|
||||
async def close(self) -> None:
|
||||
@@ -1024,13 +974,11 @@ def fresh(fresh: bool = True):
|
||||
with fresh():
|
||||
result = some_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
|
||||
token = _fresh.set(fresh or is_fresh())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset fresh mode. {id(token):#x}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_fresh(fresh: bool = True):
|
||||
@@ -1041,13 +989,11 @@ async def async_fresh(fresh: bool = True):
|
||||
async with async_fresh():
|
||||
result = await some_async_cached_function()
|
||||
"""
|
||||
token = _fresh.set(fresh)
|
||||
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
|
||||
token = _fresh.set(fresh or is_fresh())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_fresh.reset(token)
|
||||
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
|
||||
|
||||
def is_fresh() -> bool:
|
||||
"""
|
||||
@@ -1119,15 +1065,16 @@ def AsyncCache(cache_type: Literal['ttl', 'lru'] = 'ttl',
|
||||
|
||||
|
||||
def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Optional[int] = None,
|
||||
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False):
|
||||
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False, shared_key: Optional[str] = None):
|
||||
"""
|
||||
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
|
||||
|
||||
:param region: 缓存的区
|
||||
:param maxsize: 缓存的最大条目数
|
||||
:param region: 缓存区域的标识符,默认根据模块名、函数名等自动生成标识
|
||||
:param maxsize: 缓存区内的最大条目数
|
||||
:param ttl: 缓存的存活时间,单位秒,未传入则为永久缓存,单位秒
|
||||
:param skip_none: 跳过 None 缓存,默认为 True
|
||||
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
|
||||
:param shared_key: 同步/异步函数共享缓存的键,默认使用函数名(异步函数名会标准化为同步格式,如移除 `async_` 前缀)
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
|
||||
@@ -1177,6 +1124,17 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
return False
|
||||
return True
|
||||
|
||||
def __standardize_func_name() -> str:
|
||||
"""
|
||||
将异步函数名标准化为同步函数的命名,以生成统一的缓存键
|
||||
"""
|
||||
# XXX 假设异步函数名与同步版本仅差`async_`前缀或`_async`后缀(当前MP代码大多符合),否则需通过`shared_key`参数显式指定
|
||||
return (
|
||||
func.__name__.removeprefix("async_").removesuffix("_async")
|
||||
if is_async
|
||||
else func.__name__
|
||||
)
|
||||
|
||||
def __get_cache_key(args, kwargs) -> str:
|
||||
"""
|
||||
根据函数和参数生成缓存键
|
||||
@@ -1198,13 +1156,22 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
bound.arguments[param] for param in signature.parameters if param in bound.arguments
|
||||
]
|
||||
# 使用有序参数生成缓存键
|
||||
return f"{func.__name__}_{hashkey(*keys)}"
|
||||
|
||||
# 获取缓存区
|
||||
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
|
||||
return f"{func_name}_{hashkey(*keys)}"
|
||||
|
||||
# 被装饰函数的上层名称(如类名或外层函数名)
|
||||
enclosing_name = (
|
||||
func.__qualname__[:last_dot]
|
||||
if (last_dot := func.__qualname__.rfind(".")) != -1
|
||||
else ""
|
||||
)
|
||||
# 检查是否为异步函数
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
# 生成标准化后的函数名称,用于同步/异步函数共享缓存
|
||||
func_name = shared_key if shared_key else __standardize_func_name()
|
||||
# 获取缓存区
|
||||
cache_region = (
|
||||
region if region is not None else f"{func.__module__}:{enclosing_name}:{func_name}"
|
||||
)
|
||||
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
@@ -6,6 +7,7 @@ import re
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
@@ -207,6 +209,8 @@ class ConfigModel(BaseModel):
|
||||
# ==================== 云盘配置 ====================
|
||||
# 115 AppId
|
||||
U115_APP_ID: str = "100196807"
|
||||
# 115 OAuth2 Server 地址
|
||||
U115_AUTH_SERVER: str = "https://movie-pilot.org"
|
||||
# Alipan AppId
|
||||
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
|
||||
|
||||
@@ -217,7 +221,7 @@ class ConfigModel(BaseModel):
|
||||
AUTO_UPDATE_RESOURCE: bool = True
|
||||
|
||||
# ==================== 媒体文件格式配置 ====================
|
||||
# 支持的后缀格式
|
||||
# 支持的视频文件后缀格式
|
||||
RMT_MEDIAEXT: list = Field(
|
||||
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
|
||||
'.rmvb', '.avi', '.mov', '.mpeg',
|
||||
@@ -228,8 +232,6 @@ class ConfigModel(BaseModel):
|
||||
# 支持的字幕文件后缀格式
|
||||
RMT_SUBEXT: list = Field(default_factory=lambda: ['.srt', '.ass', '.ssa', '.sup'])
|
||||
# 支持的音轨文件后缀格式
|
||||
RMT_AUDIO_TRACK_EXT: list = Field(default_factory=lambda: ['.mka'])
|
||||
# 音轨文件后缀格式
|
||||
RMT_AUDIOEXT: list = Field(
|
||||
default_factory=lambda: ['.aac', '.ac3', '.amr', '.caf', '.cda', '.dsf',
|
||||
'.dff', '.kar', '.m4a', '.mp1', '.mp2', '.mp3',
|
||||
@@ -276,7 +278,7 @@ class ConfigModel(BaseModel):
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
# 最大搜索名称数量
|
||||
MAX_SEARCH_NAME_LIMIT: int = 2
|
||||
MAX_SEARCH_NAME_LIMIT: int = 3
|
||||
|
||||
# ==================== 下载配置 ====================
|
||||
# 种子标签
|
||||
@@ -303,6 +305,8 @@ class ConfigModel(BaseModel):
|
||||
COOKIECLOUD_BLACKLIST: Optional[str] = None
|
||||
|
||||
# ==================== 整理配置 ====================
|
||||
# 文件整理线程数
|
||||
TRANSFER_THREADS: int = 1
|
||||
# 电影重命名格式
|
||||
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
|
||||
@@ -335,7 +339,7 @@ class ConfigModel(BaseModel):
|
||||
"https://github.com/thsrite/MoviePilot-Plugins,"
|
||||
"https://github.com/honue/MoviePilot-Plugins,"
|
||||
"https://github.com/InfinityPacer/MoviePilot-Plugins,"
|
||||
"https://github.com/DDS-Derek/MoviePilot-Plugins,"
|
||||
"https://github.com/DDSRem-Dev/MoviePilot-Plugins,"
|
||||
"https://github.com/madrays/MoviePilot-Plugins,"
|
||||
"https://github.com/justzerock/MoviePilot-Plugins,"
|
||||
"https://github.com/KoWming/MoviePilot-Plugins,"
|
||||
@@ -345,7 +349,12 @@ class ConfigModel(BaseModel):
|
||||
"https://github.com/Aqr-K/MoviePilot-Plugins,"
|
||||
"https://github.com/hotlcc/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/gxterry/MoviePilot-Plugins,"
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins")
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins,"
|
||||
"https://github.com/mrtian2016/MoviePilot-Plugins,"
|
||||
"https://github.com/Hqyel/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/xijin285/MoviePilot-Plugins,"
|
||||
"https://github.com/Seed680/MoviePilot-Plugins,"
|
||||
"https://github.com/imaliang/MoviePilot-Plugins")
|
||||
# 插件安装数据共享
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
@@ -391,6 +400,10 @@ class ConfigModel(BaseModel):
|
||||
])
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
# 允许在未启用 OTP 时直接注册 PassKey
|
||||
PASSKEY_ALLOW_REGISTER_WITHOUT_OTP: bool = False
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
@@ -405,10 +418,14 @@ class ConfigModel(BaseModel):
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
# Playwright浏览器类型,chromium/firefox
|
||||
PLAYWRIGHT_BROWSER_TYPE: str = "chromium"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# 合局AI智能体
|
||||
AI_AGENT_GLOBAL: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
@@ -417,20 +434,32 @@ class ConfigModel(BaseModel):
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM最大上下文Token数量(K)
|
||||
LLM_MAX_CONTEXT_TOKENS: int = 64
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
LLM_MAX_ITERATIONS: int = 128
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 30
|
||||
# 内存记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
# 是否启用AI推荐
|
||||
AI_RECOMMEND_ENABLED: bool = False
|
||||
# AI推荐用户偏好
|
||||
AI_RECOMMEND_USER_PREFERENCE: str = ""
|
||||
# Tavily API密钥(用于网络搜索)
|
||||
TAVILY_API_KEY: str = "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"
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -835,6 +864,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
rename_format = re.sub(r'/+', '/', rename_format)
|
||||
return rename_format.strip("/")
|
||||
|
||||
def TMDB_IMAGE_URL(
|
||||
self, file_path: Optional[str], file_size: str = "original"
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取TMDB图片网址
|
||||
|
||||
:param file_path: TMDB API返回的xxx_path
|
||||
:param file_size: 图片大小,例如:'original', 'w500' 等
|
||||
:return: 图片的完整URL,如果 file_path 为空则返回 None
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
return (
|
||||
f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
|
||||
)
|
||||
|
||||
|
||||
# 实例化配置
|
||||
settings = Settings()
|
||||
@@ -852,6 +897,8 @@ class GlobalVar(object):
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
# 当前事件循环
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -916,6 +963,19 @@ class GlobalVar(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def loop(self) -> AbstractEventLoop:
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
"""
|
||||
设置循环
|
||||
"""
|
||||
self.CURRENT_EVENT_LOOP = loop
|
||||
|
||||
|
||||
# 全局标识
|
||||
global_vars = GlobalVar()
|
||||
|
||||
@@ -95,18 +95,20 @@ class TorrentInfo:
|
||||
if upload_volume_factor is None or download_volume_factor is None:
|
||||
return "未知"
|
||||
free_strs = {
|
||||
"1.0 1.0": "普通",
|
||||
"1.0 0.0": "免费",
|
||||
"2.0 1.0": "2X",
|
||||
"4.0 1.0": "4X",
|
||||
"2.0 0.0": "2X免费",
|
||||
"4.0 0.0": "4X免费",
|
||||
"1.0 0.5": "50%",
|
||||
"2.0 0.5": "2X 50%",
|
||||
"1.0 0.7": "70%",
|
||||
"1.0 0.3": "30%"
|
||||
"1.00 1.00": "普通",
|
||||
"1.00 0.00": "免费",
|
||||
"2.00 1.00": "2X",
|
||||
"4.00 1.00": "4X",
|
||||
"2.00 0.00": "2X免费",
|
||||
"4.00 0.00": "4X免费",
|
||||
"1.00 0.50": "50%",
|
||||
"2.00 0.50": "2X 50%",
|
||||
"1.00 0.70": "70%",
|
||||
"1.00 0.30": "30%",
|
||||
"1.00 0.75": "75%",
|
||||
"1.00 0.25": "25%"
|
||||
}
|
||||
return free_strs.get('%.1f %.1f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
return free_strs.get('%.2f %.2f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
|
||||
@property
|
||||
def volume_factor(self):
|
||||
@@ -463,7 +465,7 @@ class MediaInfo:
|
||||
for seainfo in info.get('seasons'):
|
||||
# 季
|
||||
season = seainfo.get("season_number")
|
||||
if not season:
|
||||
if season is None:
|
||||
continue
|
||||
# 集
|
||||
episode_count = seainfo.get("episode_count")
|
||||
@@ -477,11 +479,11 @@ class MediaInfo:
|
||||
self.episode_groups = info.pop("episode_groups").get("results") or []
|
||||
|
||||
# 海报
|
||||
if info.get('poster_path'):
|
||||
self.poster_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('poster_path')}"
|
||||
if path := info.get('poster_path'):
|
||||
self.poster_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 背景
|
||||
if info.get('backdrop_path'):
|
||||
self.backdrop_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('backdrop_path')}"
|
||||
if path := info.get('backdrop_path'):
|
||||
self.backdrop_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 导演和演员
|
||||
self.directors, self.actors = __directors_actors(info)
|
||||
# 别名和译名
|
||||
@@ -543,9 +545,9 @@ class MediaInfo:
|
||||
# 识别标题中的季
|
||||
meta = MetaInfo(info.get("title"))
|
||||
# 季
|
||||
if not self.season:
|
||||
if self.season is None:
|
||||
self.season = meta.begin_season
|
||||
if self.season:
|
||||
if self.season is not None:
|
||||
self.type = MediaType.TV
|
||||
elif not self.type:
|
||||
self.type = MediaType.MOVIE
|
||||
@@ -605,13 +607,13 @@ class MediaInfo:
|
||||
# 剧集
|
||||
if self.type == MediaType.TV and not self.seasons:
|
||||
meta = MetaInfo(info.get("title"))
|
||||
season = meta.begin_season or 1
|
||||
season = meta.begin_season if meta.begin_season is not None else 1
|
||||
episodes_count = info.get("episodes_count")
|
||||
if episodes_count:
|
||||
self.seasons[season] = list(range(1, episodes_count + 1))
|
||||
# 季年份
|
||||
if self.type == MediaType.TV and not self.season_years:
|
||||
season = self.season or 1
|
||||
season = self.season if self.season is not None else 1
|
||||
self.season_years = {
|
||||
season: self.year
|
||||
}
|
||||
@@ -665,7 +667,7 @@ class MediaInfo:
|
||||
# 识别标题中的季
|
||||
meta = MetaInfo(self.title)
|
||||
# 季
|
||||
if not self.season:
|
||||
if self.season is None:
|
||||
self.season = meta.begin_season
|
||||
# 评分
|
||||
if not self.vote_average:
|
||||
@@ -701,7 +703,7 @@ class MediaInfo:
|
||||
# 剧集
|
||||
if self.type == MediaType.TV and not self.seasons:
|
||||
meta = MetaInfo(self.title)
|
||||
season = meta.begin_season or 1
|
||||
season = meta.begin_season if meta.begin_season is not None else 1
|
||||
episodes_count = info.get("total_episodes")
|
||||
if episodes_count:
|
||||
self.seasons[season] = list(range(1, episodes_count + 1))
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ChainEventData
|
||||
@@ -90,8 +91,6 @@ class EventManager(metaclass=Singleton):
|
||||
self.__lock = threading.Lock()
|
||||
# 退出事件
|
||||
self.__event = threading.Event()
|
||||
# 当前事件循环
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -454,7 +453,7 @@ class EventManager(metaclass=Singleton):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
self.loop
|
||||
global_vars.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
|
||||
@@ -535,7 +535,7 @@ class MetaBase(object):
|
||||
|
||||
def merge(self, meta: Self):
|
||||
"""
|
||||
全并Meta信息
|
||||
合并Meta信息
|
||||
"""
|
||||
# 类型
|
||||
if self.type == MediaType.UNKNOWN \
|
||||
|
||||
@@ -301,7 +301,8 @@ class MetaVideo(MetaBase):
|
||||
return
|
||||
else:
|
||||
# 后缀名不要
|
||||
if ".%s".lower() % token in settings.RMT_MEDIAEXT:
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if ".%s".lower() % token in media_exts:
|
||||
return
|
||||
# 英文或者英文+数字,拼装起来
|
||||
if self.en_name:
|
||||
|
||||
@@ -25,7 +25,8 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
|
||||
# 获取标题中媒体信息
|
||||
title, metainfo = find_metainfo(title)
|
||||
# 判断是否处理文件
|
||||
if title and Path(title).suffix.lower() in settings.RMT_MEDIAEXT:
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if title and Path(title).suffix.lower() in media_exts:
|
||||
isfile = True
|
||||
# 去掉后缀
|
||||
title = Path(title).stem
|
||||
@@ -62,21 +63,24 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
|
||||
return meta
|
||||
|
||||
|
||||
def MetaInfoPath(path: Path) -> MetaBase:
|
||||
def MetaInfoPath(path: Path, custom_words: List[str] = None) -> MetaBase:
|
||||
"""
|
||||
根据路径识别元数据
|
||||
:param path: 路径
|
||||
:param custom_words: 自定义识别词列表
|
||||
"""
|
||||
# 文件元数据,不包含后缀
|
||||
file_meta = MetaInfo(title=path.name)
|
||||
file_meta = MetaInfo(title=path.name, custom_words=custom_words)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
dir_meta = MetaInfo(title=path.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
# 上上级目录元数据
|
||||
root_meta = MetaInfo(title=path.parent.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
root_meta = MetaInfo(title=path.parent.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
return file_meta
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -20,7 +20,7 @@ from watchfiles import watch
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
@@ -28,16 +28,16 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
"""
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -250,20 +250,12 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
return self._plugins
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
|
||||
return
|
||||
logger.info("配置变更,重新加载插件文件修改监测...")
|
||||
def on_config_changed(self):
|
||||
self.reload_monitor()
|
||||
|
||||
def get_reload_name(self) -> str:
|
||||
return "插件文件修改监测"
|
||||
|
||||
def reload_monitor(self):
|
||||
"""
|
||||
重新加载插件文件修改监测
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, AP
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -24,7 +25,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
# OAuth2PasswordBearer 用于 JWT Token 认证
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
oauth2_scheme_manual_error = OAuth2PasswordBearer(
|
||||
auto_error=False, # 禁用自动错误处理,用以支持API令牌鉴权
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
)
|
||||
|
||||
@@ -41,6 +43,58 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
|
||||
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用请求头
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_header or key_query # 首选请求头
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl=600)
|
||||
def __create_superuser_token_payload() -> schemas.TokenPayload:
|
||||
"""
|
||||
创建管理员用户的TokenPayload
|
||||
|
||||
:return: 管理员TokenPayload
|
||||
"""
|
||||
# 延迟导入
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=no-name-in-module
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
|
||||
user = UserOper().get_by_name(settings.SUPERUSER)
|
||||
if not user or not user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户权限不足",
|
||||
)
|
||||
return schemas.TokenPayload(
|
||||
sub=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
level=SitesHelper().auth_level,
|
||||
purpose="authentication",
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
userid: Union[str, Any],
|
||||
username: str,
|
||||
@@ -176,23 +230,43 @@ def __verify_token(token: str, purpose: Optional[str] = "authentication") -> sch
|
||||
def verify_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
token: Annotated[str, Security(oauth2_scheme)]
|
||||
jwt_token: Annotated[str | None, Security(oauth2_scheme_manual_error)],
|
||||
api_key: Annotated[str | None, Security(__get_api_key)],
|
||||
api_token: Annotated[str | None, Security(__get_api_token)],
|
||||
) -> schemas.TokenPayload:
|
||||
"""
|
||||
验证 JWT 令牌并自动处理 resource_token 写入
|
||||
|
||||
如果缺少JWT令牌再尝试用API令牌鉴权
|
||||
|
||||
:param request: 请求对象,用于访问 Cookie 和请求信息
|
||||
:param response: 响应对象,用于设置 Cookie
|
||||
:param token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param jwt_token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param api_key: 从 查询参数`apikey` 或 请求头`X-API-KEY` 获取 API Token
|
||||
:param api_token: 从 查询参数`token` 获取 API Token
|
||||
:return: 解析后的 TokenPayload
|
||||
:raises HTTPException: 如果令牌无效或用途不匹配
|
||||
"""
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=token, purpose="authentication")
|
||||
if jwt_token:
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=jwt_token, purpose="authentication")
|
||||
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
|
||||
return payload
|
||||
return payload
|
||||
elif api_key:
|
||||
verify_apikey(api_key)
|
||||
return __create_superuser_token_payload()
|
||||
elif api_token:
|
||||
verify_apitoken(api_token)
|
||||
return __create_superuser_token_payload()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def verify_resource_token(
|
||||
@@ -208,31 +282,7 @@ def verify_resource_token(
|
||||
return __verify_token(token=resource_token, purpose="resource")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用 URL 参数
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_query or key_header
|
||||
|
||||
|
||||
def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
def __verify_key(key: str | None, expected_key: str, key_type: str) -> str:
|
||||
"""
|
||||
通用的 API Key 或 Token 验证函数
|
||||
:param key: 从请求中获取的 API Key 或 Token
|
||||
@@ -241,7 +291,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
:return: 返回校验通过的 API Key 或 Token
|
||||
:raises HTTPException: 如果校验不通过,抛出 401 错误
|
||||
"""
|
||||
if key != expected_key:
|
||||
if not key or key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{key_type} 校验不通过"
|
||||
@@ -249,7 +299,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
def verify_apitoken(token: Annotated[str | None, Security(__get_api_token)]) -> str:
|
||||
"""
|
||||
使用 API Token 进行身份认证
|
||||
:param token: API Token,从 URL 查询参数中获取 token=xxx
|
||||
@@ -258,10 +308,10 @@ def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
return __verify_key(token, settings.API_TOKEN, "token")
|
||||
|
||||
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
def verify_apikey(apikey: Annotated[str | None, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx,或请求头中获取 X-API-KEY=xxx
|
||||
:return: 返回校验通过的 API Key
|
||||
"""
|
||||
return __verify_key(apikey, settings.API_TOKEN, "apikey")
|
||||
|
||||
@@ -454,7 +454,6 @@ class Base:
|
||||
|
||||
@db_update
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
@@ -462,7 +461,6 @@ class Base:
|
||||
|
||||
@async_db_update
|
||||
async def async_update(self, db: AsyncSession, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
|
||||
@@ -49,7 +49,7 @@ class MediaServerOper(DbOper):
|
||||
if not item:
|
||||
return None
|
||||
|
||||
if kwargs.get("season"):
|
||||
if kwargs.get("season") is not None:
|
||||
# 判断季是否存在
|
||||
if not item.seasoninfo:
|
||||
return None
|
||||
@@ -75,7 +75,7 @@ class MediaServerOper(DbOper):
|
||||
if not item:
|
||||
return None
|
||||
|
||||
if kwargs.get("season"):
|
||||
if kwargs.get("season") is not None:
|
||||
# 判断季是否存在
|
||||
if not item.seasoninfo:
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
|
||||
@@ -55,6 +55,8 @@ class DownloadHistory(Base):
|
||||
media_category = Column(String)
|
||||
# 剧集组
|
||||
episode_group = Column(String)
|
||||
# 自定义识别词(用于整理时应用)
|
||||
custom_words = Column(String)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
@@ -102,14 +104,14 @@ class DownloadHistory(Base):
|
||||
# TMDBID + 类型
|
||||
if tmdbid and mtype:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
@@ -122,14 +124,14 @@ class DownloadHistory(Base):
|
||||
# 标题 + 年份
|
||||
elif title and year:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
@@ -207,7 +209,7 @@ class DownloadFiles(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
|
||||
if state:
|
||||
if state is not None:
|
||||
return db.query(cls).filter(cls.download_hash == download_hash,
|
||||
cls.state == state).all()
|
||||
else:
|
||||
|
||||
@@ -65,6 +65,14 @@ class MediaServerItem(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
|
||||
if not mtype and not year:
|
||||
return db.query(cls).filter(cls.title == title).first()
|
||||
elif not year:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype).first()
|
||||
elif not mtype:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == str(year)).first()
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)).first()
|
||||
@@ -85,7 +93,16 @@ class MediaServerItem(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
if not mtype and not year:
|
||||
result = await db.execute(select(cls).filter(cls.title == title))
|
||||
elif not year:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype))
|
||||
elif not mtype:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.year == str(year)))
|
||||
else:
|
||||
result = await db.execute(select(cls).filter(cls.title == title,
|
||||
cls.item_type == mtype,
|
||||
cls.year == str(year)))
|
||||
return result.scalars().first()
|
||||
|
||||
126
app/db/models/passkey.py
Normal file
126
app/db/models/passkey.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, select, ForeignKey
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
|
||||
|
||||
|
||||
class PassKey(Base):
|
||||
"""
|
||||
用户PassKey凭证表
|
||||
"""
|
||||
# ID
|
||||
id = get_id_column()
|
||||
# 用户ID
|
||||
user_id = Column(Integer, ForeignKey('user.id'), nullable=False, index=True)
|
||||
# 凭证ID (credential_id)
|
||||
credential_id = Column(String, nullable=False, unique=True, index=True)
|
||||
# 凭证公钥
|
||||
public_key = Column(Text, nullable=False)
|
||||
# 签名计数器
|
||||
sign_count = Column(Integer, default=0)
|
||||
# 凭证名称(用户自定义)
|
||||
name = Column(String, default="通行密钥")
|
||||
# AAGUID (Authenticator Attestation GUID)
|
||||
aaguid = Column(String, nullable=True)
|
||||
# 创建时间
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
# 最后使用时间
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
# 是否启用
|
||||
is_active = Column(Boolean, default=True)
|
||||
# 传输方式 (usb, nfc, ble, internal)
|
||||
transports = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_user_id(cls, db: Session, user_id: int):
|
||||
"""获取用户的所有PassKey"""
|
||||
return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
|
||||
"""异步获取用户的所有PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_credential_id(cls, db: Session, credential_id: str):
|
||||
"""根据凭证ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
|
||||
"""异步根据凭证ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_id(cls, db: Session, passkey_id: int):
|
||||
"""根据ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.id == passkey_id).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_id(cls, db: AsyncSession, passkey_id: int):
|
||||
"""异步根据ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.id == passkey_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
|
||||
"""删除指定用户的PassKey"""
|
||||
passkey = db.query(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
).first()
|
||||
if passkey:
|
||||
passkey.delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
|
||||
"""异步删除指定用户的PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
)
|
||||
)
|
||||
passkey = result.scalars().first()
|
||||
if passkey:
|
||||
await passkey.async_delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@db_update
|
||||
def update_last_used(self, db: Session, sign_count: int):
|
||||
"""更新最后使用时间和签名计数"""
|
||||
self.update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_update_last_used(self, db: AsyncSession, sign_count: int):
|
||||
"""异步更新最后使用时间和签名计数"""
|
||||
await self.async_update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
@@ -93,7 +93,7 @@ class Subscribe(Base):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
@@ -106,7 +106,7 @@ class Subscribe(Base):
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
@@ -148,7 +148,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.name == title,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.name == title).first()
|
||||
@@ -156,7 +156,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title, cls.season == season)
|
||||
)
|
||||
@@ -169,7 +169,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).all()
|
||||
else:
|
||||
@@ -178,7 +178,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
@@ -227,6 +227,66 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by(cls, db: Session, type: str, season: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
# TMDBID
|
||||
if tmdbid:
|
||||
if season is not None:
|
||||
result = db.query(cls).filter(
|
||||
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
|
||||
)
|
||||
else:
|
||||
result = db.query(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
|
||||
# 豆瓣ID
|
||||
elif doubanid:
|
||||
result = db.query(cls).filter(cls.doubanid == doubanid, cls.type == type)
|
||||
# BangumiID
|
||||
elif bangumiid:
|
||||
result = db.query(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
|
||||
else:
|
||||
return None
|
||||
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by(cls, db: AsyncSession, type: str, season: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
# TMDBID
|
||||
if tmdbid:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
|
||||
)
|
||||
# 豆瓣ID
|
||||
elif doubanid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid, cls.type == type)
|
||||
)
|
||||
# BangumiID
|
||||
elif bangumiid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
@db_update
|
||||
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
|
||||
subscrbies = self.get_by_tmdbid(db, tmdbid, season)
|
||||
|
||||
@@ -99,7 +99,7 @@ class SubscribeHistory(Base):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
@@ -112,7 +112,7 @@ class SubscribeHistory(Base):
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
|
||||
@@ -266,14 +266,14 @@ class TransferHistory(Base):
|
||||
# TMDBID + 类型
|
||||
if tmdbid and mtype:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.type == mtype,
|
||||
cls.seasons == season,
|
||||
cls.episodes == episode,
|
||||
cls.dest == dest).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.type == mtype,
|
||||
cls.seasons == season).all()
|
||||
@@ -290,14 +290,14 @@ class TransferHistory(Base):
|
||||
# 标题 + 年份
|
||||
elif title and year:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year,
|
||||
cls.seasons == season,
|
||||
cls.episodes == episode,
|
||||
cls.dest == dest).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year,
|
||||
cls.seasons == season).all()
|
||||
@@ -312,7 +312,7 @@ class TransferHistory(Base):
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year).all()
|
||||
# 类型 + 转移路径(emby webhook season无tmdbid场景)
|
||||
elif mtype and season and dest:
|
||||
elif mtype and season is not None and dest:
|
||||
# 电视剧某季
|
||||
return db.query(cls).filter(cls.type == mtype,
|
||||
cls.seasons == season,
|
||||
|
||||
@@ -71,6 +71,7 @@ class SubscribeOper(DbOper):
|
||||
"backdrop": mediainfo.get_backdrop_image(),
|
||||
"vote": mediainfo.vote_average,
|
||||
"description": mediainfo.overview,
|
||||
"search_imdbid": 1 if kwargs.get('search_imdbid') else 0,
|
||||
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
})
|
||||
if not subscribe:
|
||||
@@ -91,7 +92,7 @@ class SubscribeOper(DbOper):
|
||||
判断是否存在
|
||||
"""
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return True if Subscribe.exists(self._db, tmdbid=tmdbid, season=season) else False
|
||||
else:
|
||||
return True if Subscribe.exists(self._db, tmdbid=tmdbid) else False
|
||||
@@ -111,6 +112,20 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
return await Subscribe.async_get(self._db, rid=sid)
|
||||
|
||||
def get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
return Subscribe.get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
|
||||
|
||||
async def async_get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
return await Subscribe.async_get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
|
||||
|
||||
def list(self, state: Optional[str] = None) -> List[Subscribe]:
|
||||
"""
|
||||
获取订阅列表
|
||||
@@ -180,7 +195,7 @@ class SubscribeOper(DbOper):
|
||||
判断是否存在订阅历史
|
||||
"""
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid, season=season) else False
|
||||
else:
|
||||
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid) else False
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from app.db import DbOper
|
||||
@@ -17,6 +19,8 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
super().__init__()
|
||||
self.__SYSTEMCONF = {}
|
||||
self._rlock = threading.RLock()
|
||||
self._alock = asyncio.Lock()
|
||||
for item in SystemConfig.list(self._db):
|
||||
self.__SYSTEMCONF[item.key] = item.value
|
||||
|
||||
@@ -29,23 +33,24 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
with self._rlock:
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
|
||||
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
|
||||
"""
|
||||
@@ -56,22 +61,32 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
async with self._alock:
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
# 确定是否需要更新数据库
|
||||
needs_db_update = False
|
||||
if conf:
|
||||
if conf.value != value:
|
||||
needs_db_update = True
|
||||
else: # 记录不存在,总是需要创建/更新
|
||||
needs_db_update = True
|
||||
if not needs_db_update:
|
||||
# 即使数据库值相同,也要确保缓存同步
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return None
|
||||
# 执行数据库更新
|
||||
if conf:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
await conf.async_update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
await conf.async_delete(self._db, conf.id)
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
# 数据库更新成功后,再更新缓存
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return True
|
||||
|
||||
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
|
||||
@@ -82,15 +97,17 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
key = key.value
|
||||
if not key:
|
||||
return self.all()
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
|
||||
def all(self):
|
||||
"""
|
||||
获取所有系统设置
|
||||
"""
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
|
||||
def delete(self, key: Union[str, SystemConfigKey]) -> bool:
|
||||
"""
|
||||
@@ -98,10 +115,11 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
with self._rlock:
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
|
||||
@@ -125,7 +125,7 @@ class TransferHistoryOper(DbOper):
|
||||
"""
|
||||
新增转移成功历史记录
|
||||
"""
|
||||
self.add_force(
|
||||
return self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.utils.http import RequestUtils, cookie_parse
|
||||
|
||||
|
||||
class PlaywrightHelper:
|
||||
def __init__(self, browser_type="chromium"):
|
||||
def __init__(self, browser_type=settings.PLAYWRIGHT_BROWSER_TYPE):
|
||||
self.browser_type = browser_type
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -19,41 +19,42 @@ class CookieHelper:
|
||||
"username": [
|
||||
'//input[@name="username"]',
|
||||
'//input[@id="form_item_username"]',
|
||||
'//input[@id="username"]'
|
||||
'//input[@id="username"]',
|
||||
],
|
||||
"password": [
|
||||
'//input[@name="password"]',
|
||||
'//input[@id="form_item_password"]',
|
||||
'//input[@id="password"]',
|
||||
'//input[@type="password"]'
|
||||
'//input[@type="password"]',
|
||||
],
|
||||
"captcha": [
|
||||
'//input[@name="imagestring"]',
|
||||
'//input[@name="captcha"]',
|
||||
'//input[@id="form_item_captcha"]',
|
||||
'//input[@placeholder="驗證碼"]'
|
||||
'//input[@placeholder="驗證碼"]',
|
||||
],
|
||||
"captcha_img": [
|
||||
'//img[@alt="captcha"]/@src',
|
||||
'//img[@alt="CAPTCHA"]/@src',
|
||||
'//img[@alt="SECURITY CODE"]/@src',
|
||||
'//img[@id="LAY-user-get-vercode"]/@src',
|
||||
'//img[contains(@src,"/api/getCaptcha")]/@src'
|
||||
'//img[contains(@src,"/api/getCaptcha")]/@src',
|
||||
],
|
||||
"submit": [
|
||||
'//input[@type="submit"]',
|
||||
'//button[@type="submit"]',
|
||||
'//button[@lay-filter="login"]',
|
||||
'//button[@lay-filter="formLogin"]',
|
||||
'//input[@type="button"][@value="登录"]'
|
||||
'//input[@type="button"][@value="登录"]',
|
||||
'//input[@id="submit-btn"]',
|
||||
],
|
||||
"error": [
|
||||
"//table[@class='main']//td[@class='text']/text()"
|
||||
"//table[@class='main']//td[@class='text']/text()",
|
||||
],
|
||||
"twostep": [
|
||||
'//input[@name="two_step_code"]',
|
||||
'//input[@name="2fa_secret"]',
|
||||
'//input[@name="otp"]'
|
||||
'//input[@name="otp"]',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
@@ -9,7 +9,7 @@ from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
|
||||
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?}}", re.DOTALL)
|
||||
|
||||
|
||||
class DirectoryHelper:
|
||||
@@ -51,7 +51,7 @@ class DirectoryHelper:
|
||||
"""
|
||||
return [d for d in self.get_library_dirs() if d.library_storage == "local"]
|
||||
|
||||
def get_dir(self, media: MediaInfo, include_unsorted: Optional[bool] = False,
|
||||
def get_dir(self, media: Optional[MediaInfo], include_unsorted: Optional[bool] = False,
|
||||
storage: Optional[str] = None, src_path: Path = None,
|
||||
target_storage: Optional[str] = None, dest_path: Path = None
|
||||
) -> Optional[schemas.TransferDirectoryConf]:
|
||||
@@ -64,11 +64,8 @@ class DirectoryHelper:
|
||||
:param src_path: 源目录,有值时直接匹配
|
||||
:param dest_path: 目标目录,有值时直接匹配
|
||||
"""
|
||||
# 处理类型
|
||||
if not media:
|
||||
return None
|
||||
# 电影/电视剧
|
||||
media_type = media.type.value
|
||||
media_type = media.type.value if media else None
|
||||
dirs = self.get_dirs()
|
||||
|
||||
# 如果存在源目录,并源目录为任一下载目录的子目录时,则进行源目录匹配,否则,允许源目录按同盘优先的逻辑匹配
|
||||
@@ -93,7 +90,7 @@ class DirectoryHelper:
|
||||
if dest_path and dest_path != Path(d.library_path):
|
||||
continue
|
||||
# 目录类型为全部的,符合条件
|
||||
if not d.media_type:
|
||||
if not media_type or not d.media_type:
|
||||
matched_dirs.append(d)
|
||||
continue
|
||||
# 目录类型相等,目录类别为全部,符合条件
|
||||
@@ -109,11 +106,27 @@ class DirectoryHelper:
|
||||
# 优先源目录同盘
|
||||
for matched_dir in matched_dirs:
|
||||
matched_path = Path(matched_dir.download_path)
|
||||
if SystemUtils.is_same_disk(matched_path, src_path):
|
||||
if self._is_same_source((src_path, storage or "local"), (matched_path, matched_dir.library_storage)):
|
||||
return matched_dir
|
||||
return matched_dirs[0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_same_source(src: Tuple[Path, str], tar: Tuple[Path, str]) -> bool:
|
||||
"""
|
||||
判断源目录和目标目录是否在同一存储盘
|
||||
|
||||
:param src: 源目录路径和存储类型
|
||||
:param tar: 目标目录路径和存储类型
|
||||
:return: 是否在同一存储盘
|
||||
"""
|
||||
src_path, src_storage = src
|
||||
tar_path, tar_storage = tar
|
||||
if "local" == tar_storage == src_storage:
|
||||
return SystemUtils.is_same_disk(src_path, tar_path)
|
||||
# 网络存储,直接比较类型
|
||||
return src_storage == tar_storage
|
||||
|
||||
@staticmethod
|
||||
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
|
||||
"""
|
||||
@@ -129,19 +142,22 @@ class DirectoryHelper:
|
||||
# 计算重命名中的文件夹层数
|
||||
rename_list = rename_format.split("/")
|
||||
rename_format_level = len(rename_list) - 1
|
||||
# 查找标题参数所在层
|
||||
for level, name in enumerate(rename_list):
|
||||
# 反向查找标题参数所在层
|
||||
for level, name in enumerate(reversed(rename_list)):
|
||||
if level == 0:
|
||||
# 跳过文件名的标题参数
|
||||
continue
|
||||
matchs = JINJA2_VAR_PATTERN.findall(name)
|
||||
if not matchs:
|
||||
continue
|
||||
# 处理特例,有的人重命名的第一层是年份、分辨率
|
||||
if any("title" in m for m in matchs):
|
||||
# 找出含标题的这一层作为媒体根目录
|
||||
rename_format_level -= level
|
||||
# 找出最后一层含有标题参数的目录作为媒体根目录
|
||||
rename_format_level = level
|
||||
break
|
||||
else:
|
||||
# 假定第一层目录是媒体根目录
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题目录")
|
||||
if rename_format_level > len(rename_path.parents):
|
||||
# 通常因为路径以/结尾,被Path规范化删除了
|
||||
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")
|
||||
|
||||
@@ -14,10 +14,8 @@ from threading import Lock
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 定义一个全局线程池执行器
|
||||
@@ -69,25 +67,23 @@ def enable_doh(enable: bool):
|
||||
socket.getaddrinfo = _orig_getaddrinfo
|
||||
|
||||
|
||||
class DohHelper(metaclass=Singleton):
|
||||
class DohHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
DoH帮助类,用于处理DNS over HTTPS解析。
|
||||
"""
|
||||
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
|
||||
|
||||
def __init__(self):
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ["DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"]:
|
||||
return
|
||||
def on_config_changed(self):
|
||||
with _doh_lock:
|
||||
# DOH配置有变动的情况下,清空缓存
|
||||
_doh_cache.clear()
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
def get_reload_name(self):
|
||||
return 'DoH'
|
||||
|
||||
def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@@ -25,7 +25,7 @@ class DownloaderHelper(ServiceBaseHelper[DownloaderConf]):
|
||||
) -> bool:
|
||||
"""
|
||||
通用的下载器类型判断方法
|
||||
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission')
|
||||
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission', 'rtorrent')
|
||||
:param service: 要判断的服务信息
|
||||
:param name: 服务的名称
|
||||
:return: 如果服务类型或实例为指定类型,返回 True;否则返回 False
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, FileCache, AsyncFileCache
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.ip import IpUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@@ -161,3 +168,121 @@ class WallpaperHelper(metaclass=Singleton):
|
||||
return wallpaper_list
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class ImageHelper(metaclass=Singleton):
|
||||
|
||||
def __init__(self):
|
||||
_base_path = settings.CACHE_PATH
|
||||
_ttl = settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600
|
||||
self.file_cache = FileCache(base=_base_path, ttl=_ttl)
|
||||
self.async_file_cache = AsyncFileCache(base=_base_path, ttl=_ttl)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_cache_path(url: str) -> str:
|
||||
"""缓存路径"""
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = Path(sanitized_path)
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
return cache_path.as_posix()
|
||||
|
||||
@staticmethod
|
||||
def _validate_image(content: bytes) -> bool:
|
||||
"""验证图片"""
|
||||
if not content:
|
||||
return False
|
||||
try:
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warn(f"Invalid image format: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_request_params(url: str, proxy: Optional[bool], cookies: Optional[str | dict]) -> dict:
|
||||
"""获取参数"""
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
if proxy is None:
|
||||
proxies = settings.PROXY if not (referer or IpUtils.is_internal(url)) else None
|
||||
else:
|
||||
proxies = settings.PROXY if proxy else None
|
||||
return {
|
||||
"ua": settings.NORMAL_USER_AGENT,
|
||||
"proxies": proxies,
|
||||
"referer": referer,
|
||||
"cookies": cookies,
|
||||
"accept_type": "image/avif,image/webp,image/apng,*/*",
|
||||
}
|
||||
|
||||
def fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(同步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = self.file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = RequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
self.file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
|
||||
async def async_fetch_image(
|
||||
self,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = True,
|
||||
cookies: Optional[str | dict] = None) -> Optional[bytes]:
|
||||
"""
|
||||
获取图片(异步版本)
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
cache_path = self._prepare_cache_path(url)
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
content = await self.async_file_cache.get(cache_path, region="images")
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = await AsyncRequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
content = response.content
|
||||
# 验证图片
|
||||
if not self._validate_image(content):
|
||||
return None
|
||||
|
||||
# 保存缓存
|
||||
await self.async_file_cache.set(cache_path, content, region="images")
|
||||
return content
|
||||
108
app/helper/llm.py
Normal file
108
app/helper/llm.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param callbacks: 回调处理器列表
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
return self._get_google_models(api_key)
|
||||
else:
|
||||
return self._get_openai_compatible_models(provider, api_key, base_url)
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表"""
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=api_key)
|
||||
models = genai.list_models()
|
||||
return [m.name for m in models if 'generateContent' in m.supported_generation_methods]
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
if provider == "deepseek":
|
||||
base_url = base_url or "https://api.deepseek.com"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
models = client.models.list()
|
||||
return [model.id for model in models.data]
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
||||
raise e
|
||||
@@ -539,7 +539,7 @@ class MessageTemplateHelper:
|
||||
获取消息模板
|
||||
"""
|
||||
template_dict: dict[str, str] = SystemConfigOper().get(SystemConfigKey.NotificationTemplates)
|
||||
return template_dict.get(f"{message.ctype.value}")
|
||||
return template_dict.get(message.ctype.value)
|
||||
|
||||
|
||||
class MessageQueueManager(metaclass=SingletonClass):
|
||||
|
||||
361
app/helper/passkey.py
Normal file
361
app/helper/passkey.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
PassKey WebAuthn 辅助工具类
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import binascii
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from webauthn import (
|
||||
generate_registration_options,
|
||||
verify_registration_response,
|
||||
generate_authentication_options,
|
||||
verify_authentication_response,
|
||||
options_to_json
|
||||
)
|
||||
from webauthn.helpers import (
|
||||
parse_registration_credential_json,
|
||||
parse_authentication_credential_json
|
||||
)
|
||||
from webauthn.helpers.structs import (
|
||||
PublicKeyCredentialDescriptor,
|
||||
AuthenticatorTransport,
|
||||
UserVerificationRequirement,
|
||||
AuthenticatorAttachment,
|
||||
ResidentKeyRequirement,
|
||||
AuthenticatorSelectionCriteria
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PassKeyHelper:
|
||||
"""
|
||||
PassKey WebAuthn 辅助类
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_rp_id() -> str:
|
||||
"""
|
||||
获取 Relying Party ID
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
app_domain = settings.APP_DOMAIN.strip()
|
||||
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
|
||||
if not app_domain.startswith(('http://', 'https://')):
|
||||
app_domain = f'https://{app_domain}'
|
||||
parsed = urlparse(app_domain)
|
||||
host = parsed.hostname
|
||||
if host:
|
||||
return host
|
||||
# 从 APP_DOMAIN 中提取域名
|
||||
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
|
||||
# 移除端口号
|
||||
if ':' in host:
|
||||
host = host.split(':')[0]
|
||||
return host
|
||||
# 只有在未配置 APP_DOMAIN 时,才默认为 localhost
|
||||
return 'localhost'
|
||||
|
||||
@staticmethod
|
||||
def get_rp_name() -> str:
|
||||
"""
|
||||
获取 Relying Party 名称
|
||||
"""
|
||||
return "MoviePilot"
|
||||
|
||||
@staticmethod
|
||||
def get_origin() -> str:
|
||||
"""
|
||||
获取源地址
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
return settings.APP_DOMAIN.rstrip('/')
|
||||
# 如果未配置APP_DOMAIN,使用默认的localhost地址
|
||||
return f'http://localhost:{settings.NGINX_PORT}'
|
||||
|
||||
@staticmethod
|
||||
def standardize_credential_id(credential_id: str) -> str:
|
||||
"""
|
||||
标准化凭证ID(Base64 URL Safe)
|
||||
"""
|
||||
try:
|
||||
# Base64解码并重新编码以标准化格式
|
||||
decoded = base64.urlsafe_b64decode(credential_id + '==')
|
||||
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
|
||||
except (binascii.Error, TypeError, ValueError) as e:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@staticmethod
|
||||
def _base64_encode_urlsafe(data: bytes) -> str:
|
||||
"""
|
||||
Base64 URL Safe 编码(不带填充)
|
||||
|
||||
:param data: 要编码的字节数据
|
||||
:return: Base64 URL Safe 编码的字符串
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
|
||||
|
||||
@staticmethod
|
||||
def _base64_decode_urlsafe(data: str) -> bytes:
|
||||
"""
|
||||
Base64 URL Safe 解码(自动添加填充)
|
||||
|
||||
:param data: Base64 URL Safe 编码的字符串
|
||||
:return: 解码后的字节数据
|
||||
"""
|
||||
return base64.urlsafe_b64decode(data + '==')
|
||||
|
||||
@staticmethod
|
||||
def _parse_credential_list(credentials: List[Dict[str, Any]]) -> List[PublicKeyCredentialDescriptor]:
|
||||
"""
|
||||
解析凭证列表为 PublicKeyCredentialDescriptor 列表
|
||||
|
||||
:param credentials: 凭证字典列表
|
||||
:return: PublicKeyCredentialDescriptor 列表
|
||||
"""
|
||||
result = []
|
||||
for cred in credentials:
|
||||
try:
|
||||
result.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=PassKeyHelper._base64_decode_urlsafe(cred['credential_id']),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_user_verification_requirement(user_verification: Optional[str] = None) -> UserVerificationRequirement:
|
||||
"""
|
||||
获取用户验证要求
|
||||
|
||||
:param user_verification: 指定的用户验证要求,如果不指定则从配置中读取
|
||||
:return: UserVerificationRequirement
|
||||
"""
|
||||
if user_verification:
|
||||
return UserVerificationRequirement(user_verification)
|
||||
return UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
|
||||
@staticmethod
|
||||
def _get_verification_params(
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
获取验证参数(origin 和 rp_id)
|
||||
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (origin, rp_id)
|
||||
"""
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
return origin, rp_id
|
||||
|
||||
@staticmethod
|
||||
def generate_registration_options(
|
||||
user_id: int,
|
||||
username: str,
|
||||
display_name: Optional[str] = None,
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成注册选项
|
||||
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param display_name: 显示名称
|
||||
:param existing_credentials: 已存在的凭证列表
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 用户信息
|
||||
user_id_bytes = str(user_id).encode('utf-8')
|
||||
|
||||
# 排除已有的凭证
|
||||
exclude_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement()
|
||||
|
||||
# 生成注册选项
|
||||
options = generate_registration_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
rp_name=PassKeyHelper.get_rp_name(),
|
||||
user_id=user_id_bytes,
|
||||
user_name=username,
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=uv_requirement,
|
||||
),
|
||||
supported_pub_key_algs=[
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
|
||||
]
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge(用于后续验证)
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成注册选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def verify_registration_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str, int, Optional[str]]:
|
||||
"""
|
||||
验证注册响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (credential_id, public_key, sign_count, aaguid)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码challenge
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
|
||||
# 构建RegistrationCredential对象
|
||||
registration_credential = parse_registration_credential_json(json.dumps(credential))
|
||||
|
||||
# 验证注册响应
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
# 提取信息
|
||||
credential_id = PassKeyHelper._base64_encode_urlsafe(verification.credential_id)
|
||||
public_key = PassKeyHelper._base64_encode_urlsafe(verification.credential_public_key)
|
||||
sign_count = verification.sign_count
|
||||
# aaguid 可能已经是字符串格式,也可能是bytes
|
||||
if verification.aaguid:
|
||||
if isinstance(verification.aaguid, bytes):
|
||||
aaguid = verification.aaguid.hex()
|
||||
else:
|
||||
aaguid = str(verification.aaguid)
|
||||
else:
|
||||
aaguid = None
|
||||
|
||||
return credential_id, public_key, sign_count, aaguid
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证注册响应失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def generate_authentication_options(
|
||||
existing_credentials: Optional[List[Dict[str, Any]]] = None,
|
||||
user_verification: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
生成认证选项
|
||||
|
||||
:param existing_credentials: 已存在的凭证列表(用于限制可用凭证)
|
||||
:param user_verification: 用户验证要求,如果不指定则从配置中读取
|
||||
:return: (options_json, challenge)
|
||||
"""
|
||||
try:
|
||||
# 允许的凭证
|
||||
allow_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement(user_verification)
|
||||
|
||||
# 生成认证选项
|
||||
options = generate_authentication_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
allow_credentials=allow_credentials,
|
||||
user_verification=uv_requirement
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
# 提取challenge
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成认证选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def verify_authentication_response(
|
||||
credential: Dict[str, Any],
|
||||
expected_challenge: str,
|
||||
credential_public_key: str,
|
||||
credential_current_sign_count: int,
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
验证认证响应
|
||||
|
||||
:param credential: 客户端返回的凭证
|
||||
:param expected_challenge: 期望的challenge
|
||||
:param credential_public_key: 凭证公钥
|
||||
:param credential_current_sign_count: 当前签名计数
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (验证成功, 新的签名计数)
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
public_key_bytes = PassKeyHelper._base64_decode_urlsafe(credential_public_key)
|
||||
|
||||
# 构建AuthenticationCredential对象
|
||||
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
|
||||
|
||||
# 验证认证响应
|
||||
verification = verify_authentication_response(
|
||||
credential=authentication_credential,
|
||||
expected_challenge=challenge_bytes,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
credential_public_key=public_key_bytes,
|
||||
credential_current_sign_count=credential_current_sign_count,
|
||||
require_user_verification=settings.PASSKEY_REQUIRE_UV
|
||||
)
|
||||
|
||||
return True, verification.new_sign_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证认证响应失败: {e}")
|
||||
return False, credential_current_sign_count
|
||||
@@ -7,10 +7,8 @@ import redis
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 类型缓存集合,针对非容器简单类型
|
||||
@@ -74,16 +72,17 @@ def deserialize(value: bytes) -> Any:
|
||||
raise ValueError("Unknown serialization format")
|
||||
|
||||
|
||||
class RedisHelper(metaclass=Singleton):
|
||||
class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
- 支持内存限制和淘汰策略设置
|
||||
- 提供键名生成和区域管理功能
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -114,25 +113,17 @@ class RedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis...")
|
||||
def on_config_changed(self):
|
||||
self.close()
|
||||
self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis"
|
||||
|
||||
def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
@@ -310,10 +301,10 @@ class RedisHelper(metaclass=Singleton):
|
||||
logger.debug("Redis connection closed")
|
||||
|
||||
|
||||
class AsyncRedisHelper(metaclass=Singleton):
|
||||
class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
异步Redis连接和操作助手类,单例模式
|
||||
|
||||
|
||||
特性:
|
||||
- 管理异步Redis连接池和客户端
|
||||
- 提供序列化和反序列化功能
|
||||
@@ -321,6 +312,7 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
- 提供键名生成和区域管理功能
|
||||
- 所有操作都是异步的
|
||||
"""
|
||||
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -351,25 +343,17 @@ class AsyncRedisHelper(metaclass=Singleton):
|
||||
self.client = None
|
||||
raise RuntimeError("Redis async connection failed") from e
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
async def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新Redis设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
|
||||
return
|
||||
logger.info("配置变更,重连Redis (async)...")
|
||||
async def on_config_changed(self):
|
||||
await self.close()
|
||||
await self._connect()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "Redis (async)"
|
||||
|
||||
async def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
|
||||
"""
|
||||
动态设置Redis最大内存和内存淘汰策略
|
||||
|
||||
|
||||
:param policy: 淘汰策略(如'allkeys-lru')
|
||||
"""
|
||||
try:
|
||||
|
||||
@@ -382,7 +382,10 @@ class RssHelper:
|
||||
size = int(size_attr)
|
||||
|
||||
# 发布日期
|
||||
pubdate_nodes = item.xpath('.//pubDate | .//published | .//updated')
|
||||
pubdate_nodes = item.xpath('./pubDate | ./published | ./updated')
|
||||
if not pubdate_nodes:
|
||||
pubdate_nodes = item.xpath('.//*[local-name()="pubDate"] | .//*[local-name()="published"] | .//*[local-name()="updated"]')
|
||||
|
||||
pubdate = ""
|
||||
if pubdate_nodes and pubdate_nodes[0].text:
|
||||
pubdate = StringUtils.get_time(pubdate_nodes[0].text)
|
||||
|
||||
@@ -8,35 +8,32 @@ from typing import Tuple
|
||||
import docker
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.log import logger
|
||||
from app.schemas import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class SystemHelper:
|
||||
class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
系统工具类,提供系统相关的操作和判断
|
||||
"""
|
||||
CONFIG_WATCH = {
|
||||
"DEBUG",
|
||||
"LOG_LEVEL",
|
||||
"LOG_MAX_FILE_SIZE",
|
||||
"LOG_BACKUP_COUNT",
|
||||
"LOG_FILE_FORMAT",
|
||||
"LOG_CONSOLE_FORMAT",
|
||||
}
|
||||
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件,更新日志设置
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEBUG', 'LOG_LEVEL', 'LOG_MAX_FILE_SIZE', 'LOG_BACKUP_COUNT',
|
||||
'LOG_FILE_FORMAT', 'LOG_CONSOLE_FORMAT']:
|
||||
return
|
||||
logger.info("配置变更,更新日志设置...")
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
|
||||
def get_reload_name(self):
|
||||
return "日志设置"
|
||||
|
||||
@staticmethod
|
||||
def can_restart() -> bool:
|
||||
"""
|
||||
|
||||
@@ -6,8 +6,7 @@ from urllib.parse import unquote
|
||||
|
||||
from torrentool.api import Torrent
|
||||
|
||||
from app.core.cache import FileCache
|
||||
from app.core.cache import TTLCache
|
||||
from app.core.cache import TTLCache, FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, TorrentInfo, MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
@@ -26,7 +25,7 @@ class TorrentHelper:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._invalid_torrents = TTLCache(maxsize=128, ttl=3600 * 24)
|
||||
self._invalid_torrents = TTLCache(region="invalid_torrents", maxsize=128, ttl=3600 * 24)
|
||||
|
||||
def download_torrent(self, url: str,
|
||||
cookie: Optional[str] = None,
|
||||
@@ -341,11 +340,11 @@ class TorrentHelper:
|
||||
episodes = list(set(episodes).union(set(meta.episode_list)))
|
||||
return episodes
|
||||
|
||||
def is_invalid(self, url: str) -> bool:
|
||||
def is_invalid(self, url: Optional[str]) -> bool:
|
||||
"""
|
||||
判断种子是否是无效种子
|
||||
"""
|
||||
return url in self._invalid_torrents
|
||||
return url in self._invalid_torrents if url else True
|
||||
|
||||
def add_invalid(self, url: str):
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas import Notification, NotificationConf, MediaServerConf, DownloaderConf
|
||||
from app.schemas.types import ModuleType, DownloaderType, MediaServerType, MessageChannel, StorageSchema, \
|
||||
OtherModulesType
|
||||
OtherModulesType, SystemConfigKey
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
|
||||
|
||||
class _ModuleBase(metaclass=ABCMeta):
|
||||
class _ModuleBase(ConfigReloadMixin, metaclass=ABCMeta):
|
||||
"""
|
||||
模块基类,实现对应方法,在有需要时会被自动调用,返回None代表不启用该模块,将继续执行下一模块
|
||||
输入参数与输出参数一致的,或没有输出的,可以被多个模块重复实现
|
||||
"""
|
||||
|
||||
def on_config_changed(self):
|
||||
self.init_module()
|
||||
|
||||
def get_reload_name(self):
|
||||
return self.get_name()
|
||||
|
||||
@abstractmethod
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -177,6 +185,7 @@ class _MessageBase(ServiceBase[TService, NotificationConf]):
|
||||
"""
|
||||
消息基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Notifications.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -224,6 +233,7 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
"""
|
||||
下载器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.Downloaders.value}
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
@@ -281,12 +291,37 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
|
||||
重置默认配置名称
|
||||
"""
|
||||
self._default_config_name = None
|
||||
|
||||
def normalize_path(self, path: Path, downloader: Optional[str]) -> str:
|
||||
"""
|
||||
根据下载器配置和路径映射,规范化下载路径
|
||||
|
||||
:param path: 存储路径
|
||||
:param downloader: 下载器名称
|
||||
:return: 规范化后发送给下载器的路径
|
||||
"""
|
||||
dir = path.as_posix()
|
||||
conf = self.get_config(downloader)
|
||||
if conf and conf.path_mapping:
|
||||
for (storage_path, download_path) in conf.path_mapping:
|
||||
storage_path = Path(storage_path.strip()).as_posix()
|
||||
download_path = Path(download_path.strip()).as_posix()
|
||||
if dir.startswith(storage_path):
|
||||
dir = dir.replace(storage_path, download_path, 1)
|
||||
break
|
||||
# 去掉存储协议前缀 if any, 下载器无法识别
|
||||
for s in StorageSchema:
|
||||
prefix = f"{s.value}:"
|
||||
if dir.startswith(prefix):
|
||||
return dir[len(prefix):]
|
||||
return dir
|
||||
|
||||
|
||||
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
|
||||
"""
|
||||
媒体服务器基类
|
||||
"""
|
||||
CONFIG_WATCH = {SystemConfigKey.MediaServers.value}
|
||||
|
||||
def get_configs(self) -> Dict[str, MediaServerConf]:
|
||||
"""
|
||||
|
||||
@@ -290,3 +290,11 @@ class BangumiModule(_ModuleBase):
|
||||
if infos:
|
||||
return [MediaInfo(bangumi_info=info) for info in infos]
|
||||
return []
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清除缓存
|
||||
"""
|
||||
logger.info(f"开始清除{self.get_name()}缓存 ...")
|
||||
self.bangumiapi.clear_cache()
|
||||
logger.info(f"{self.get_name()}缓存清除完成")
|
||||
|
||||
@@ -31,7 +31,7 @@ class BangumiApi(object):
|
||||
self._req = RequestUtils(ua=settings.NORMAL_USER_AGENT, session=self._session)
|
||||
self._async_req = AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT)
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
|
||||
def __invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
req_url = self._base_url + url
|
||||
params = {}
|
||||
@@ -47,7 +47,7 @@ class BangumiApi(object):
|
||||
print(e)
|
||||
return None
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
|
||||
async def __async_invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
req_url = self._base_url + url
|
||||
params = {}
|
||||
@@ -300,6 +300,12 @@ class BangumiApi(object):
|
||||
key="data",
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'), **kwargs)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清除缓存
|
||||
"""
|
||||
self.__invoke.cache_clear()
|
||||
|
||||
def close(self):
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
235
app/modules/discord/__init__.py
Normal file
235
app/modules/discord/__init__.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import json
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
except Exception as err: # ImportError or other load issues
|
||||
Discord = None
|
||||
logger.error(f"Discord 模块未加载,缺少依赖或初始化错误:{err}")
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
"""
|
||||
if not Discord:
|
||||
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
|
||||
return
|
||||
self.stop()
|
||||
super().init_service(service_name=Discord.__name__.lower(),
|
||||
service_type=Discord)
|
||||
self._channel = MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "Discord"
|
||||
|
||||
@staticmethod
|
||||
def get_type() -> ModuleType:
|
||||
"""
|
||||
获取模块类型
|
||||
"""
|
||||
return ModuleType.Notification
|
||||
|
||||
@staticmethod
|
||||
def get_subtype() -> MessageChannel:
|
||||
"""
|
||||
获取模块子类型
|
||||
"""
|
||||
return MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
def get_priority() -> int:
|
||||
"""
|
||||
获取模块优先级,数字越小优先级越高,只有同一接口下优先级才生效
|
||||
"""
|
||||
return 4
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止模块
|
||||
"""
|
||||
for client in self.get_instances().values():
|
||||
client.stop()
|
||||
|
||||
def test(self) -> Optional[Tuple[bool, str]]:
|
||||
"""
|
||||
测试模块连接性
|
||||
"""
|
||||
if not self.get_instances():
|
||||
return None
|
||||
for name, client in self.get_instances().items():
|
||||
state = client.get_state()
|
||||
if not state:
|
||||
return False, f"Discord {name} Bot 未就绪"
|
||||
return True, ""
|
||||
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
|
||||
"""
|
||||
解析消息内容,返回字典,注意以下约定值:
|
||||
userid: 用户ID
|
||||
username: 用户名
|
||||
text: 内容
|
||||
:param source: 消息来源
|
||||
:param body: 请求体
|
||||
:param form: 表单
|
||||
:param args: 参数
|
||||
:return: 渠道、消息体
|
||||
"""
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
try:
|
||||
msg_json: dict = json.loads(body)
|
||||
except Exception as e:
|
||||
logger.debug(f"解析 Discord 消息失败:{str(e)}")
|
||||
return None
|
||||
|
||||
if not msg_json:
|
||||
return None
|
||||
|
||||
msg_type = msg_json.get("type")
|
||||
userid = msg_json.get("userid")
|
||||
username = msg_json.get("username")
|
||||
|
||||
if msg_type == "interaction":
|
||||
callback_data = msg_json.get("callback_data")
|
||||
message_id = msg_json.get("message_id")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if callback_data and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}")
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=f"CALLBACK:{callback_data}",
|
||||
is_callback=True,
|
||||
callback_data=callback_data,
|
||||
message_id=message_id,
|
||||
chat_id=str(chat_id) if chat_id else None
|
||||
)
|
||||
return None
|
||||
|
||||
if msg_type == "message":
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if text and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}")
|
||||
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
|
||||
userid=userid, username=username, text=text,
|
||||
chat_id=str(chat_id) if chat_id else None)
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
:param message: 消息通知对象
|
||||
"""
|
||||
# DEBUG: Log entry and configs
|
||||
configs = self.get_configs()
|
||||
logger.debug(f"[Discord] post_message 被调用,message.source={message.source}, "
|
||||
f"message.userid={message.userid}, message.channel={message.channel}")
|
||||
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
|
||||
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
|
||||
|
||||
if not configs:
|
||||
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
return
|
||||
|
||||
for conf in configs.values():
|
||||
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
|
||||
if not self.check_message(message, conf.name):
|
||||
logger.debug(f"[Discord] check_message 返回 False,跳过配置: {conf.name}")
|
||||
continue
|
||||
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get('discord_userid')
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
|
||||
if client:
|
||||
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
|
||||
result = client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
:param message: 消息体
|
||||
:param medias: 媒体信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
|
||||
"""
|
||||
发送种子信息选择列表
|
||||
:param message: 消息体
|
||||
:param torrents: 种子信息
|
||||
:return: 成功或失败
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_torrents_msg(title=message.title, torrents=torrents,
|
||||
userid=message.userid, buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
|
||||
def delete_message(self, channel: MessageChannel, source: str,
|
||||
message_id: str, chat_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 指定的消息源
|
||||
:param message_id: 消息ID(Slack中为时间戳)
|
||||
:param chat_id: 聊天ID(频道ID)
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.delete_msg(message_id=message_id, chat_id=chat_id)
|
||||
if result:
|
||||
success = True
|
||||
return success
|
||||
714
app/modules/discord/discord.py
Normal file
714
app/modules/discord/discord.py
Normal file
@@ -0,0 +1,714 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import NotificationType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
# Discord embed 字段解析白名单
|
||||
# 只有这些消息类型会使用复杂的字段解析逻辑
|
||||
PARSE_FIELD_TYPES = {
|
||||
NotificationType.Download, # 资源下载
|
||||
NotificationType.Organize, # 整理入库
|
||||
NotificationType.Subscribe, # 订阅
|
||||
NotificationType.Manual, # 手动处理
|
||||
}
|
||||
|
||||
|
||||
class Discord:
|
||||
"""
|
||||
Discord Bot 通知与交互实现(基于 discord.py 2.6.4)
|
||||
"""
|
||||
|
||||
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
|
||||
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
|
||||
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
# URL encode the source name to handle special characters in config names
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.messages = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client: Optional[discord.Client] = discord.Client(
|
||||
intents=intents,
|
||||
proxy=settings.PROXY_HOST
|
||||
)
|
||||
self._tree: Optional[app_commands.CommandTree] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
self._register_events()
|
||||
self._start()
|
||||
|
||||
@staticmethod
|
||||
def _to_int(val: Optional[Union[str, int]]) -> Optional[int]:
|
||||
try:
|
||||
return int(val) if val is not None and str(val).strip() else None
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _register_events(self):
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
self._bot_user_id = self._client.user.id if self._client.user else None
|
||||
self._ready_event.set()
|
||||
logger.info(f"Discord Bot 已登录:{self._client.user}")
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: discord.Message):
|
||||
if message.author.bot:
|
||||
return
|
||||
if not self._should_process_message(message):
|
||||
return
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
payload = {
|
||||
"type": "message",
|
||||
"userid": str(message.author.id),
|
||||
"username": username,
|
||||
"user_tag": str(message.author),
|
||||
"text": cleaned_text,
|
||||
"message_id": str(message.id),
|
||||
"chat_id": str(message.channel.id),
|
||||
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@self._client.event
|
||||
async def on_interaction(interaction: discord.Interaction):
|
||||
if interaction.type == discord.InteractionType.component:
|
||||
data = interaction.data or {}
|
||||
callback_data = data.get("custom_id")
|
||||
if not callback_data:
|
||||
return
|
||||
try:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Discord 交互响应失败:{e}")
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
if interaction.user and interaction.channel:
|
||||
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
payload = {
|
||||
"type": "interaction",
|
||||
"userid": str(interaction.user.id) if interaction.user else None,
|
||||
"username": username,
|
||||
"user_tag": str(interaction.user) if interaction.user else None,
|
||||
"callback_data": callback_data,
|
||||
"message_id": str(interaction.message.id) if interaction.message else None,
|
||||
"chat_id": str(interaction.channel.id) if interaction.channel else None
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
def _start(self):
|
||||
if self._thread:
|
||||
return
|
||||
|
||||
def runner():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.create_task(self._client.start(self._token))
|
||||
self._loop.run_forever()
|
||||
except Exception as err:
|
||||
logger.error(f"Discord Bot 启动失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.run_until_complete(self._client.close())
|
||||
except Exception as err:
|
||||
logger.debug(f"Discord Bot 关闭失败:{err}")
|
||||
|
||||
self._thread = threading.Thread(target=runner, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
if not self._client or not self._loop or not self._thread:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
|
||||
except Exception as err:
|
||||
logger.error(f"关闭 Discord Bot 失败:{err}")
|
||||
finally:
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
except Exception as err:
|
||||
logger.error(f"停止 Discord 事件循环失败:{err}")
|
||||
self._ready_event.clear()
|
||||
|
||||
def get_state(self) -> bool:
|
||||
return self._ready_event.is_set() and self._client is not None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
|
||||
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
|
||||
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
|
||||
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
|
||||
f"_client = {self._client is not None}")
|
||||
if not self.get_state():
|
||||
logger.warning("[Discord] get_state() 返回 False,Bot 未就绪,无法发送消息")
|
||||
return False
|
||||
if not title and not text:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.debug(f"[Discord] 准备异步发送消息...")
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype),
|
||||
self._loop)
|
||||
result = future.result(timeout=30)
|
||||
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
|
||||
return result
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not medias:
|
||||
return False
|
||||
title = title or "媒体列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_media_embeds(medias, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 媒体列表失败:{err}")
|
||||
return False
|
||||
|
||||
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state() or not torrents:
|
||||
return False
|
||||
title = title or "种子列表"
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_list_message(
|
||||
embeds=self._build_torrent_embeds(torrents, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 种子列表失败:{err}")
|
||||
return False
|
||||
|
||||
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._delete_message(message_id=message_id, chat_id=chat_id),
|
||||
self._loop
|
||||
)
|
||||
return future.result(timeout=15)
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
|
||||
userid: Optional[str], link: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional['NotificationType'] = None) -> bool:
|
||||
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype)
|
||||
view = self._build_view(buttons=buttons, link=link)
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
fallback_buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
|
||||
view = self._build_view(buttons=buttons if buttons else fallback_buttons)
|
||||
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=None, embed=None, view=view, embeds=embeds)
|
||||
|
||||
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
|
||||
embeds=embeds if len(embeds) > 1 else None,
|
||||
view=view)
|
||||
return True
|
||||
|
||||
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
|
||||
content: Optional[str], embed: Optional[discord.Embed],
|
||||
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=str(chat_id))
|
||||
if not channel:
|
||||
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
kwargs: Dict[str, Any] = {"content": content, "view": view}
|
||||
if embeds:
|
||||
if len(embeds) == 1:
|
||||
kwargs["embed"] = embeds[0]
|
||||
else:
|
||||
kwargs["embeds"] = embeds
|
||||
elif embed:
|
||||
kwargs["embed"] = embed
|
||||
await message.edit(**kwargs)
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"编辑 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=chat_id)
|
||||
if not channel:
|
||||
logger.error("删除 Discord 消息时未找到频道")
|
||||
return False
|
||||
try:
|
||||
message = await channel.fetch_message(int(message_id))
|
||||
await message.delete()
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_embed(title: str, text: Optional[str], image: Optional[str],
|
||||
link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed:
|
||||
fields: List[Dict[str, str]] = []
|
||||
desc_lines: List[str] = []
|
||||
should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False
|
||||
def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]:
|
||||
spans: List[Tuple[int, int]] = []
|
||||
start = 0
|
||||
while True:
|
||||
l_idx = s.find(left, start)
|
||||
if l_idx == -1:
|
||||
break
|
||||
r_idx = s.find(right, l_idx + 1)
|
||||
if r_idx == -1:
|
||||
break
|
||||
spans.append((l_idx, r_idx))
|
||||
start = r_idx + 1
|
||||
return spans
|
||||
|
||||
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
|
||||
segment = s[m.start():m.end()]
|
||||
for i, ch in enumerate(segment):
|
||||
if ch in (":", ":"):
|
||||
return m.start() + i
|
||||
return None
|
||||
|
||||
if text:
|
||||
# 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符
|
||||
if "\\n" in text or "\\r" in text:
|
||||
text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n")
|
||||
if not should_parse_fields:
|
||||
desc_lines.append(text.strip())
|
||||
else:
|
||||
# 匹配形如 "字段:值" 的片段,字段名不允许包含常见分隔符;
|
||||
# 下一个字段需以顿号/逗号/分号等分隔开,且不能是 URL 协议开头,避免值里出现 URL 的":" 被误拆
|
||||
# 字段名允许 emoji 等 Unicode 字符,但排除空白/分隔符/冒号
|
||||
name_re = r"[^\s::,,。;;、]+"
|
||||
pair_pattern = re.compile(
|
||||
rf"({name_re})[::](.*?)(?=(?:[,,。;;、]+\s*(?!https?://|ftp://|ftps://|magnet:){name_re}[::])|$)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
matches = list(pair_pattern.finditer(line))
|
||||
if matches:
|
||||
book_spans = _collect_spans(line, "《", "》") + _collect_spans(line, "【", "】")
|
||||
if book_spans:
|
||||
has_book_colon = False
|
||||
for m in matches:
|
||||
colon_idx = _find_colon_index(line, m)
|
||||
if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans):
|
||||
has_book_colon = True
|
||||
break
|
||||
if has_book_colon:
|
||||
desc_lines.append(line)
|
||||
continue
|
||||
# 若整行只是 URL/时间等自然包含":"的内容,则不当作字段
|
||||
url_like_names = {"http", "https", "ftp", "ftps", "magnet"}
|
||||
if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches):
|
||||
desc_lines.append(line)
|
||||
continue
|
||||
last_end = 0
|
||||
for m in matches:
|
||||
# 追加匹配前的非空文本到描述
|
||||
prefix = line[last_end:m.start()].strip(" ,,;;。、")
|
||||
# 仅当前缀不全是分隔符/空白时才记录
|
||||
if prefix and prefix.strip(" ,,;;。、"):
|
||||
desc_lines.append(prefix)
|
||||
name = m.group(1).strip()
|
||||
value = m.group(2).strip(" ,,;;。、\t") or "-"
|
||||
if name:
|
||||
fields.append({"name": name, "value": value, "inline": False})
|
||||
last_end = m.end()
|
||||
# 匹配末尾后的文本
|
||||
suffix = line[last_end:].strip(" ,,;;。、")
|
||||
if suffix and suffix.strip(" ,,;;。、"):
|
||||
desc_lines.append(suffix)
|
||||
else:
|
||||
desc_lines.append(line)
|
||||
description = "\n".join(desc_lines).strip()
|
||||
if not description and not fields and text:
|
||||
description = text.strip()
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
url=link or "https://github.com/jxxghp/MoviePilot",
|
||||
description=description if description else None,
|
||||
color=0xE67E22
|
||||
)
|
||||
for field in fields:
|
||||
embed.add_field(name=field["name"], value=field["value"], inline=False)
|
||||
if image:
|
||||
embed.set_image(url=image)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def _build_media_embeds(medias: List[MediaInfo], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, media in enumerate(medias[:10], start=1):
|
||||
overview = media.get_overview_string(80)
|
||||
desc_parts = [
|
||||
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
|
||||
overview
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {media.title_year}",
|
||||
url=media.detail_link or discord.Embed.Empty,
|
||||
description="\n".join([p for p in desc_parts if p]),
|
||||
color=0x5865F2
|
||||
)
|
||||
if media.get_poster_image():
|
||||
embed.set_thumbnail(url=media.get_poster_image())
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, context in enumerate(torrents[:10], start=1):
|
||||
torrent = context.torrent_info
|
||||
meta = MetaInfo(torrent.title, torrent.description)
|
||||
title_text = f"{meta.season_episode} {meta.resource_term} {meta.video_term} {meta.release_group}"
|
||||
title_text = re.sub(r"\s+", " ", title_text).strip()
|
||||
detail = [
|
||||
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}↑",
|
||||
meta.resource_term,
|
||||
meta.video_term
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {title_text or torrent.title}",
|
||||
url=torrent.page_url or discord.Embed.Empty,
|
||||
description="\n".join([d for d in detail if d]),
|
||||
color=0x00A86B
|
||||
)
|
||||
poster = getattr(torrent, "poster", None)
|
||||
if poster:
|
||||
embed.set_thumbnail(url=poster)
|
||||
embeds.append(embed)
|
||||
if embeds:
|
||||
embeds[0].set_author(name=title)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_default_buttons(count: int) -> List[List[dict]]:
|
||||
buttons: List[List[dict]] = []
|
||||
max_rows = 5
|
||||
max_per_row = 5
|
||||
capped = min(count, max_rows * max_per_row)
|
||||
for idx in range(1, capped + 1):
|
||||
row_idx = (idx - 1) // max_per_row
|
||||
if len(buttons) <= row_idx:
|
||||
buttons.append([])
|
||||
buttons[row_idx].append({"text": f"选择 {idx}", "callback_data": str(idx)})
|
||||
if count > capped:
|
||||
logger.warn(f"按钮数量超过 Discord 限制,仅展示前 {capped} 个")
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
|
||||
has_buttons = buttons and any(buttons)
|
||||
if not has_buttons and not link:
|
||||
return None
|
||||
|
||||
view = discord.ui.View(timeout=None)
|
||||
if buttons:
|
||||
for row_index, button_row in enumerate(buttons[:5]):
|
||||
for button in button_row[:5]:
|
||||
if "url" in button:
|
||||
btn = discord.ui.Button(label=button.get("text", "链接"),
|
||||
url=button["url"],
|
||||
style=discord.ButtonStyle.link)
|
||||
else:
|
||||
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
|
||||
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
|
||||
custom_id=custom_id,
|
||||
style=discord.ButtonStyle.primary)
|
||||
view.add_item(btn)
|
||||
elif link:
|
||||
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
"""
|
||||
Resolve the channel to send messages to.
|
||||
Priority order:
|
||||
1. `chat_id` (original channel where user sent the message) - for contextual replies
|
||||
2. `userid` mapping (channel where user last sent a message) - for contextual replies
|
||||
3. Configured `_channel_id` (broadcast channel) - for system notifications
|
||||
4. Any available text channel in configured guild - fallback
|
||||
5. `userid` (DM) - for private conversations as a final fallback
|
||||
"""
|
||||
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
|
||||
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
|
||||
|
||||
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
|
||||
if chat_id:
|
||||
logger.debug(f"[Discord] 尝试通过 chat_id={chat_id} 获取原始频道")
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过 get_channel 找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# Priority 2: Use user-chat mapping (reply to where the user last sent a message)
|
||||
if userid:
|
||||
mapped_chat_id = self._get_user_chat_id(str(userid))
|
||||
if mapped_chat_id:
|
||||
logger.debug(f"[Discord] 从用户映射获取 chat_id={mapped_chat_id}")
|
||||
channel = self._client.get_channel(int(mapped_chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过映射找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
channel = await self._client.fetch_channel(int(mapped_chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# Priority 3: Use configured broadcast channel (for system notifications)
|
||||
if self._broadcast_channel:
|
||||
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
channel = await self._client.fetch_channel(self._channel_id)
|
||||
except Exception as err:
|
||||
logger.warn(f"通过配置的频道ID获取 Discord 频道失败:{err}")
|
||||
channel = None
|
||||
self._broadcast_channel = channel
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过配置的频道ID找到频道: {channel}")
|
||||
return channel
|
||||
|
||||
# Priority 4: Find any available text channel in guild (fallback)
|
||||
logger.debug(f"[Discord] 尝试在 Guild 中寻找可用频道")
|
||||
target_guilds = []
|
||||
if self._guild_id:
|
||||
guild = self._client.get_guild(self._guild_id)
|
||||
if guild:
|
||||
target_guilds.append(guild)
|
||||
else:
|
||||
target_guilds = list(self._client.guilds)
|
||||
logger.debug(f"[Discord] 目标 Guilds 数量: {len(target_guilds)}")
|
||||
|
||||
for guild in target_guilds:
|
||||
for channel in guild.text_channels:
|
||||
if guild.me and channel.permissions_for(guild.me).send_messages:
|
||||
logger.debug(f"[Discord] 在 Guild 中找到可用频道: {channel}")
|
||||
self._broadcast_channel = channel
|
||||
return channel
|
||||
|
||||
# Priority 5: Fallback to DM (only if no channel available)
|
||||
if userid:
|
||||
logger.debug(f"[Discord] 回退到私聊: userid={userid}")
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
logger.debug(f"[Discord] 获取到私聊频道: {dm}")
|
||||
return dm
|
||||
else:
|
||||
logger.debug(f"[Discord] 无法获取用户 {userid} 的私聊频道")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
|
||||
if userid in self._user_dm_cache:
|
||||
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
|
||||
user_obj = self._client.get_user(int(userid))
|
||||
logger.debug(f"[Discord] get_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
user_obj = await self._client.fetch_user(int(userid))
|
||||
logger.debug(f"[Discord] fetch_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
logger.debug(f"[Discord] 无法找到用户 {userid}")
|
||||
return None
|
||||
dm = user_obj.dm_channel
|
||||
logger.debug(f"[Discord] 用户现有 dm_channel: {dm}")
|
||||
if not dm:
|
||||
dm = await user_obj.create_dm()
|
||||
logger.debug(f"[Discord] 创建新的 dm_channel: {dm}")
|
||||
if dm:
|
||||
self._user_dm_cache[userid] = dm
|
||||
return dm
|
||||
except Exception as err:
|
||||
logger.error(f"获取 Discord 私聊失败:{err}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: str, chat_id: str) -> None:
|
||||
"""
|
||||
Update user-chat mapping for reply targeting.
|
||||
This ensures replies go to the same channel where the user sent the message.
|
||||
:param userid: User ID
|
||||
:param chat_id: Channel/Chat ID where the user sent the message
|
||||
"""
|
||||
if userid and chat_id:
|
||||
self._user_chat_mapping[userid] = chat_id
|
||||
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
|
||||
|
||||
def _get_user_chat_id(self, userid: str) -> Optional[str]:
|
||||
"""
|
||||
Get the chat ID where the user last sent a message.
|
||||
:param userid: User ID
|
||||
:return: Chat ID or None if not found
|
||||
"""
|
||||
return self._user_chat_mapping.get(userid)
|
||||
|
||||
def _should_process_message(self, message: discord.Message) -> bool:
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return True
|
||||
content = message.content or ""
|
||||
# 仅处理 @Bot 或斜杠命令
|
||||
if self._client.user and self._client.user.mentioned_in(message):
|
||||
return True
|
||||
if content.startswith("/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _clean_bot_mention(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
if self._bot_user_id:
|
||||
mention_pattern = rf"<@!?{self._bot_user_id}>"
|
||||
content = re.sub(mention_pattern, "", content).strip()
|
||||
return content
|
||||
|
||||
async def _post_to_ds(self, payload: Dict[str, Any]) -> None:
|
||||
try:
|
||||
proxy = None
|
||||
if settings.PROXY:
|
||||
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
|
||||
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
|
||||
await client.post(self._ds_url, json=payload)
|
||||
except Exception as err:
|
||||
logger.error(f"转发 Discord 消息失败:{err}")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user