mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-10 17:42:45 +08:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ded58adbb | ||
|
|
019a077407 | ||
|
|
0f190057d3 | ||
|
|
840c8f7298 | ||
|
|
6a6bcf59a0 | ||
|
|
323844b26d | ||
|
|
140d224a9a | ||
|
|
7bc032d17c | ||
|
|
2df476dbff | ||
|
|
bae086d8b8 | ||
|
|
221eb21694 | ||
|
|
4208c79d72 | ||
|
|
90245a13e1 | ||
|
|
b5979b9b09 | ||
|
|
0277288a41 | ||
|
|
79bfeaf2af | ||
|
|
4fe41ba5e9 | ||
|
|
14d6e2febc | ||
|
|
97c7e71207 | ||
|
|
8f29a218ea | ||
|
|
4fd5aa3eb6 | ||
|
|
bfc27d151c | ||
|
|
f2b56b8f40 | ||
|
|
a05ffc07d4 | ||
|
|
4a81417fb7 | ||
|
|
c7fa3dc863 | ||
|
|
28f9756dd6 | ||
|
|
4bffe2cff1 | ||
|
|
fca478f1d8 | ||
|
|
097dff13a3 | ||
|
|
460b386004 | ||
|
|
89bf89c02d | ||
|
|
cefb60ba2c | ||
|
|
8c78627647 |
@@ -4,7 +4,8 @@ import re
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
@@ -24,6 +25,7 @@ from app.agent.middleware.jobs import JobsMiddleware
|
||||
from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
@@ -41,6 +43,39 @@ class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionUsageSnapshot:
|
||||
model: Optional[str] = None
|
||||
context_window_tokens: Optional[int] = None
|
||||
last_input_tokens: int = 0
|
||||
last_output_tokens: int = 0
|
||||
last_total_tokens: int = 0
|
||||
last_context_usage_ratio: Optional[float] = None
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
model_call_count: int = 0
|
||||
last_updated_at: Optional[datetime] = None
|
||||
|
||||
def to_dict(self, session_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"model": self.model,
|
||||
"context_window_tokens": self.context_window_tokens,
|
||||
"last_input_tokens": self.last_input_tokens,
|
||||
"last_output_tokens": self.last_output_tokens,
|
||||
"last_total_tokens": self.last_total_tokens,
|
||||
"last_context_usage_ratio": self.last_context_usage_ratio,
|
||||
"total_input_tokens": self.total_input_tokens,
|
||||
"total_output_tokens": self.total_output_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"model_call_count": self.model_call_count,
|
||||
"last_updated_at": self.last_updated_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.last_updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
class _ThinkTagStripper:
|
||||
"""
|
||||
流式剥离 <think>...</think> 标签的辅助类。
|
||||
@@ -73,7 +108,7 @@ class _ThinkTagStripper:
|
||||
on_output(self.buffer[:start_idx])
|
||||
emitted = True
|
||||
self.in_think_tag = True
|
||||
self.buffer = self.buffer[start_idx + 7:]
|
||||
self.buffer = self.buffer[start_idx + 7 :]
|
||||
else:
|
||||
# 检查是否以 <think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -93,7 +128,7 @@ class _ThinkTagStripper:
|
||||
end_idx = self.buffer.find("</think>")
|
||||
if end_idx != -1:
|
||||
self.in_think_tag = False
|
||||
self.buffer = self.buffer[end_idx + 8:]
|
||||
self.buffer = self.buffer[end_idx + 8 :]
|
||||
else:
|
||||
# 检查是否以 </think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -138,10 +173,92 @@ class MoviePilotAgent:
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, llm: Any) -> Optional[str]:
|
||||
return (
|
||||
getattr(llm, "model", None)
|
||||
or getattr(llm, "model_name", None)
|
||||
or getattr(llm, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
|
||||
profile = getattr(llm, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
if isinstance(profile, dict):
|
||||
return cls._coerce_int(
|
||||
profile.get("max_input_tokens") or profile.get("input_token_limit")
|
||||
)
|
||||
return cls._coerce_int(
|
||||
getattr(profile, "max_input_tokens", None)
|
||||
or getattr(profile, "input_token_limit", None)
|
||||
)
|
||||
|
||||
def _sync_model_profile(self, llm: Any) -> None:
|
||||
model_name = self._get_model_name(llm)
|
||||
context_window_tokens = self._get_context_window_tokens(llm)
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
self._session_usage.context_window_tokens = context_window_tokens
|
||||
|
||||
def _record_usage(self, usage: dict[str, Any]) -> None:
|
||||
if not usage:
|
||||
return
|
||||
|
||||
model_name = usage.get("model")
|
||||
context_window_tokens = self._coerce_int(usage.get("context_window_tokens"))
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
self._session_usage.context_window_tokens = context_window_tokens
|
||||
|
||||
self._session_usage.model_call_count += 1
|
||||
self._session_usage.last_updated_at = datetime.now()
|
||||
|
||||
if not usage.get("has_usage"):
|
||||
return
|
||||
|
||||
input_tokens = self._coerce_int(usage.get("input_tokens")) or 0
|
||||
output_tokens = self._coerce_int(usage.get("output_tokens")) or 0
|
||||
total_tokens = self._coerce_int(usage.get("total_tokens"))
|
||||
if total_tokens is None:
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
self._session_usage.last_input_tokens = input_tokens
|
||||
self._session_usage.last_output_tokens = output_tokens
|
||||
self._session_usage.last_total_tokens = total_tokens
|
||||
self._session_usage.last_context_usage_ratio = usage.get("context_usage_ratio")
|
||||
self._session_usage.total_input_tokens += input_tokens
|
||||
self._session_usage.total_output_tokens += output_tokens
|
||||
self._session_usage.total_tokens += total_tokens
|
||||
|
||||
def get_session_status(self) -> dict[str, Any]:
|
||||
if not self._session_usage.model:
|
||||
self._session_usage.model = settings.LLM_MODEL
|
||||
if not self._session_usage.context_window_tokens:
|
||||
self._session_usage.context_window_tokens = (
|
||||
settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
if settings.LLM_MAX_CONTEXT_TOKENS
|
||||
else None
|
||||
)
|
||||
return self._session_usage.to_dict(self.session_id)
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""
|
||||
@@ -258,6 +375,12 @@ class MoviePilotAgent:
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(llm)
|
||||
|
||||
# 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
||||
non_streaming_llm = (
|
||||
llm if not streaming else self._initialize_llm(streaming=False)
|
||||
)
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
@@ -279,8 +402,12 @@ class MoviePilotAgent:
|
||||
ActivityLogMiddleware(
|
||||
activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"),
|
||||
),
|
||||
# 用量统计
|
||||
UsageMiddleware(on_usage=self._record_usage),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
|
||||
SummarizationMiddleware(
|
||||
model=non_streaming_llm, trigger=("fraction", 0.85)
|
||||
),
|
||||
# 错误工具调用修复
|
||||
PatchToolCallsMiddleware(),
|
||||
]
|
||||
@@ -289,7 +416,8 @@ class MoviePilotAgent:
|
||||
if settings.LLM_MAX_TOOLS > 0:
|
||||
middlewares.append(
|
||||
LLMToolSelectorMiddleware(
|
||||
model=llm, max_tools=settings.LLM_MAX_TOOLS
|
||||
model=non_streaming_llm,
|
||||
max_tools=settings.LLM_MAX_TOOLS,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -371,10 +499,6 @@ class MoviePilotAgent:
|
||||
:param on_token: 收到有效 token 时的回调
|
||||
"""
|
||||
stripper = _ThinkTagStripper()
|
||||
# 非VERBOSE模式下,跟踪当前langgraph_step以检测中间步骤的模型输出
|
||||
# 当模型在工具调用之前输出的"计划/思考"文本,会在检测到tool_call时被清除
|
||||
current_model_step = -1
|
||||
has_emitted_in_step = False
|
||||
|
||||
async for chunk in agent.astream(
|
||||
messages,
|
||||
@@ -388,25 +512,13 @@ class MoviePilotAgent:
|
||||
if not token or not hasattr(token, "tool_call_chunks"):
|
||||
continue
|
||||
|
||||
# 获取当前步骤信息
|
||||
step = metadata.get("langgraph_step", -1) if metadata else -1
|
||||
|
||||
if token.tool_call_chunks:
|
||||
# 检测到工具调用token:说明当前步骤是中间步骤
|
||||
# 非VERBOSE模式下,清除该步骤之前输出的"计划/思考"文本
|
||||
if not settings.AI_AGENT_VERBOSE and has_emitted_in_step:
|
||||
self.stream_handler.reset()
|
||||
stripper.reset()
|
||||
has_emitted_in_step = False
|
||||
# 清除 stripper 内部缓冲中可能残留的 <think> 标签中间状态
|
||||
stripper.reset()
|
||||
continue
|
||||
|
||||
# 以下处理纯文本token(tool_call_chunks为空)
|
||||
|
||||
# 检测步骤变化,重置步骤内emit跟踪
|
||||
if step != current_model_step:
|
||||
current_model_step = step
|
||||
has_emitted_in_step = False
|
||||
|
||||
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content)
|
||||
additional = getattr(token, "additional_kwargs", None)
|
||||
if additional and additional.get("reasoning_content"):
|
||||
@@ -416,8 +528,7 @@ class MoviePilotAgent:
|
||||
# content 可能是字符串或内容块列表,过滤掉思考类型的块
|
||||
content = self._extract_text_content(token.content)
|
||||
if content:
|
||||
if stripper.process(content, on_token):
|
||||
has_emitted_in_step = True
|
||||
stripper.process(content, on_token)
|
||||
|
||||
stripper.flush(on_token)
|
||||
|
||||
@@ -457,7 +568,10 @@ class MoviePilotAgent:
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=lambda token: (self.stream_handler.emit(token), self._emit_output(token)),
|
||||
on_token=lambda token: (
|
||||
self.stream_handler.emit(token),
|
||||
self._emit_output(token),
|
||||
),
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
@@ -622,6 +736,37 @@ class AgentManager:
|
||||
# 重试整理缓冲区锁
|
||||
self._retry_transfer_lock = asyncio.Lock()
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
agent = self.active_agents.get(session_id)
|
||||
if agent:
|
||||
status = agent.get_session_status()
|
||||
else:
|
||||
status = {
|
||||
"session_id": session_id,
|
||||
"model": settings.LLM_MODEL,
|
||||
"context_window_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
if settings.LLM_MAX_CONTEXT_TOKENS
|
||||
else None,
|
||||
"last_input_tokens": 0,
|
||||
"last_output_tokens": 0,
|
||||
"last_total_tokens": 0,
|
||||
"last_context_usage_ratio": None,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"model_call_count": 0,
|
||||
"last_updated_at": None,
|
||||
}
|
||||
|
||||
queue = self._session_queues.get(session_id)
|
||||
status["pending_messages"] = queue.qsize() if queue else 0
|
||||
status["is_processing"] = (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
@@ -1004,7 +1149,6 @@ class AgentManager:
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
await self.process_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -180,7 +182,7 @@ class StreamingHandler:
|
||||
# 检查是否所有缓冲内容都已发送
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_msg_text = self._buffer[self._msg_start_offset :]
|
||||
current_msg_text = self._buffer[self._msg_start_offset:]
|
||||
all_sent = (
|
||||
self._message_response is not None
|
||||
and self._sent_text
|
||||
@@ -246,7 +248,7 @@ class StreamingHandler:
|
||||
"""
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
current_text = self._buffer[self._msg_start_offset:]
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
@@ -256,7 +258,8 @@ class StreamingHandler:
|
||||
try:
|
||||
if self._message_response is None:
|
||||
# 第一次发送:发送新消息并获取 message_id
|
||||
response = chain.send_direct_message(
|
||||
response = await run_in_threadpool(
|
||||
chain.send_direct_message,
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
@@ -264,7 +267,7 @@ class StreamingHandler:
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
),
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
@@ -291,13 +294,14 @@ class StreamingHandler:
|
||||
)
|
||||
with self._lock:
|
||||
self._msg_start_offset += len(self._sent_text)
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
current_text = self._buffer[self._msg_start_offset:]
|
||||
self._message_response = None
|
||||
self._sent_text = ""
|
||||
|
||||
# 如果偏移后还有新内容,立即发送为新消息
|
||||
if current_text:
|
||||
response = chain.send_direct_message(
|
||||
response = await run_in_threadpool(
|
||||
chain.send_direct_message,
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
@@ -305,7 +309,7 @@ class StreamingHandler:
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
),
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
@@ -324,7 +328,8 @@ class StreamingHandler:
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
success = await run_in_threadpool(
|
||||
chain.edit_message,
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
@@ -360,3 +365,11 @@ class StreamingHandler:
|
||||
是否已经通过流式输出发送过消息(当前轮次)
|
||||
"""
|
||||
return self._message_response is not None
|
||||
|
||||
@property
|
||||
def last_buffer_char(self) -> str:
|
||||
"""
|
||||
返回当前缓冲区最后一个字符;缓冲区为空时返回空字符串。
|
||||
"""
|
||||
with self._lock:
|
||||
return self._buffer[-1:] if self._buffer else ""
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -124,34 +124,29 @@ Default memory file: {memory_file}
|
||||
</agent_memory>
|
||||
|
||||
<memory_onboarding>
|
||||
**IMPORTANT — First-time user detected!**
|
||||
First-time user detected.
|
||||
|
||||
The memory directory is currently empty. This means this is likely the user's first interaction, or their preferences have been reset.
|
||||
The memory directory is currently empty. This likely means the user has no saved long-term preferences yet.
|
||||
|
||||
**Your MANDATORY first action in this conversation:**
|
||||
Before doing ANYTHING else (before answering questions, before calling tools, before performing any task), you MUST proactively greet the user warmly and ask them about their preferences so you can provide personalized service going forward. Specifically, ask about:
|
||||
**Behavior requirements:**
|
||||
- Do NOT interrupt the current task just to collect preferences.
|
||||
- Do NOT proactively greet warmly, build rapport, or ask a long onboarding questionnaire.
|
||||
- Default to a concise, professional style until the user states a preference.
|
||||
- Only ask for preferences when they are directly useful for the current task, or when a short follow-up question at the end would clearly help future interactions.
|
||||
|
||||
1. **How to address the user** — Ask what name or nickname they'd like you to call them (e.g., a real name, a nickname, or a fun title). This is the top priority for building a personal connection.
|
||||
2. **Communication style preference** — Do they prefer a cute/playful tone (with emojis), a formal/professional tone, a concise/minimalist style, or something else?
|
||||
3. **Media preferences** — What types of media do they primarily care about? (e.g., movies, TV shows, anime, documentaries, etc.)
|
||||
4. **Quality preferences** — Do they have preferred video quality (4K, 1080p), codecs (H.265, H.264), or subtitle language preferences?
|
||||
5. **Any other special requests** — Anything else they'd like you to always keep in mind?
|
||||
**What to collect when useful:**
|
||||
- Preferred communication style
|
||||
- Media interests
|
||||
- Quality / codec / subtitle preferences
|
||||
- Any standing rules the user wants you to follow
|
||||
|
||||
**After the user replies**, you MUST immediately:
|
||||
1. Use the `write_file` tool to save ALL their preferences to the memory file at: `{memory_file}`
|
||||
2. Format the memory file in clean Markdown with clear sections (e.g., `## User Profile`, `## Communication Style`, `## Media Preferences`, etc.)
|
||||
3. The `## User Profile` section MUST include the user's preferred name/nickname at the top
|
||||
4. Only AFTER saving the preferences, proceed to help with whatever the user originally asked about (if anything)
|
||||
5. From this point on, always address the user by their preferred name/nickname in conversations
|
||||
6. You may also create additional `.md` files in the memory directory (`{memory_dir}`) for different topics as needed.
|
||||
**When the user provides lasting preferences**, you MUST promptly save them to `{memory_file}` using `write_file` or `edit_file`.
|
||||
|
||||
**If the user skips the preference questions** and directly asks you to do something:
|
||||
- Go ahead and help them with their request first
|
||||
- But still ask about their preferences naturally at the end of the interaction
|
||||
- Save whatever you learn about them (implicit or explicit) to the memory file
|
||||
|
||||
**Example onboarding flow:**
|
||||
The greeting should introduce yourself, explain this is the first meeting, and ask the above questions in a numbered list. Adapt the tone to your persona defined in the base system prompt.
|
||||
**Memory format requirements:**
|
||||
- Use clean Markdown with short sections.
|
||||
- Record only durable preferences and working rules.
|
||||
- Do NOT invent personal details or preferred names.
|
||||
- Do NOT force use of a nickname or personalized greeting.
|
||||
</memory_onboarding>
|
||||
|
||||
<memory_guidelines>
|
||||
|
||||
184
app/agent/middleware/usage.py
Normal file
184
app/agent/middleware/usage.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UsageMiddleware(AgentMiddleware):
|
||||
"""记录模型调用 usage 信息并回传给外部会话。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
on_usage: Callable[[dict[str, Any]], None] | None = None,
|
||||
) -> None:
|
||||
self.on_usage = on_usage
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _lookup_int(cls, container: Any, *keys: str) -> int | None:
|
||||
if not container:
|
||||
return None
|
||||
|
||||
getter = getattr(container, "get", None)
|
||||
if callable(getter):
|
||||
for key in keys:
|
||||
value = getter(key)
|
||||
if value is not None:
|
||||
return cls._coerce_int(value)
|
||||
|
||||
for key in keys:
|
||||
value = getattr(container, key, None)
|
||||
if value is not None:
|
||||
return cls._coerce_int(value)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_model_name(cls, model: Any) -> str | None:
|
||||
return (
|
||||
getattr(model, "model", None)
|
||||
or getattr(model, "model_name", None)
|
||||
or getattr(model, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_context_window_tokens(cls, model: Any) -> int | None:
|
||||
profile = getattr(model, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
return cls._lookup_int(profile, "max_input_tokens", "input_token_limit")
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, ai_message: AIMessage) -> dict[str, Any]:
|
||||
usage_metadata = getattr(ai_message, "usage_metadata", None)
|
||||
|
||||
input_tokens = cls._lookup_int(usage_metadata, "input_tokens")
|
||||
output_tokens = cls._lookup_int(usage_metadata, "output_tokens")
|
||||
total_tokens = cls._lookup_int(usage_metadata, "total_tokens")
|
||||
|
||||
response_metadata = getattr(ai_message, "response_metadata", None) or {}
|
||||
token_usage = (
|
||||
response_metadata.get("token_usage")
|
||||
or response_metadata.get("usage")
|
||||
or response_metadata.get("usage_metadata")
|
||||
or {}
|
||||
)
|
||||
|
||||
if input_tokens is None:
|
||||
input_tokens = cls._lookup_int(
|
||||
token_usage,
|
||||
"prompt_tokens",
|
||||
"input_tokens",
|
||||
)
|
||||
if input_tokens is None:
|
||||
input_tokens = cls._lookup_int(
|
||||
response_metadata,
|
||||
"prompt_token_count",
|
||||
"input_tokens",
|
||||
)
|
||||
|
||||
if output_tokens is None:
|
||||
output_tokens = cls._lookup_int(
|
||||
token_usage,
|
||||
"completion_tokens",
|
||||
"output_tokens",
|
||||
)
|
||||
if output_tokens is None:
|
||||
output_tokens = cls._lookup_int(
|
||||
response_metadata,
|
||||
"candidates_token_count",
|
||||
"output_tokens",
|
||||
)
|
||||
|
||||
if total_tokens is None:
|
||||
total_tokens = cls._lookup_int(token_usage, "total_tokens")
|
||||
if total_tokens is None:
|
||||
total_tokens = cls._lookup_int(response_metadata, "total_token_count")
|
||||
|
||||
has_usage = any(
|
||||
value is not None for value in (input_tokens, output_tokens, total_tokens)
|
||||
)
|
||||
resolved_input = input_tokens or 0
|
||||
resolved_output = output_tokens or 0
|
||||
resolved_total = (
|
||||
total_tokens
|
||||
if total_tokens is not None
|
||||
else resolved_input + resolved_output
|
||||
)
|
||||
|
||||
return {
|
||||
"has_usage": has_usage,
|
||||
"input_tokens": resolved_input,
|
||||
"output_tokens": resolved_output,
|
||||
"total_tokens": resolved_total,
|
||||
}
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
response = await handler(request)
|
||||
|
||||
if not callable(self.on_usage):
|
||||
return response
|
||||
|
||||
try:
|
||||
ai_message = next(
|
||||
(
|
||||
message
|
||||
for message in reversed(response.result)
|
||||
if isinstance(message, AIMessage)
|
||||
),
|
||||
None,
|
||||
)
|
||||
usage = (
|
||||
self._extract_usage(ai_message)
|
||||
if ai_message
|
||||
else {
|
||||
"has_usage": False,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
)
|
||||
context_window_tokens = self._extract_context_window_tokens(request.model)
|
||||
context_usage_ratio = None
|
||||
if context_window_tokens and usage["has_usage"]:
|
||||
context_usage_ratio = usage["input_tokens"] / context_window_tokens
|
||||
|
||||
self.on_usage(
|
||||
{
|
||||
"model": self._extract_model_name(request.model),
|
||||
"context_window_tokens": context_window_tokens,
|
||||
"context_usage_ratio": context_usage_ratio,
|
||||
**usage,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("记录模型 usage 失败: %s", e)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
__all__ = ["UsageMiddleware"]
|
||||
@@ -15,9 +15,12 @@ Core Capabilities:
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
|
||||
- Tone: friendly, concise. Like a knowledgeable friend, not a corporate bot.
|
||||
- Use emojis sparingly (1-3 per response): greetings, completions, errors.
|
||||
- Tone: professional, concise, restrained.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Prioritize task progress over conversation. Answer only what is necessary to move the task forward.
|
||||
- Do NOT flatter the user, praise the question, or use overly eager/service-oriented phrases.
|
||||
- Do NOT use emojis, exclamation marks, cute language, or excessive apology.
|
||||
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
@@ -34,6 +37,7 @@ Core Capabilities:
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- NO praise, emotional cushioning, or unnecessary politeness padding.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment + suggestion, then move on.
|
||||
</response_format>
|
||||
@@ -56,6 +60,7 @@ Core Capabilities:
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</media_management_rules>
|
||||
|
||||
<markdown_spec>
|
||||
|
||||
@@ -82,11 +82,13 @@ class PromptManager:
|
||||
verbose_spec = ""
|
||||
if not settings.AI_AGENT_VERBOSE:
|
||||
verbose_spec = (
|
||||
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
|
||||
"text, thinking processes, or explanations before or during tool calls. Call tools "
|
||||
"directly without any transitional phrases. "
|
||||
"You MUST remain completely silent until the task is completely finished. "
|
||||
"DO NOT output any content whatsoever until your final summary reply."
|
||||
"\n\n[Important Instruction] STRICTLY ENFORCED: "
|
||||
"If tools are needed, DO NOT output any conversational text, explanations, progress updates, "
|
||||
"or acknowledgements before the first tool call or between tool calls. "
|
||||
"Call tools directly without any transitional phrases. "
|
||||
"You MUST remain completely silent until all required tools have finished and you have the final result. "
|
||||
"Only then may you send one final user-facing reply. "
|
||||
"DO NOT output any intermediate content whatsoever."
|
||||
)
|
||||
|
||||
# MoviePilot系统信息
|
||||
@@ -193,18 +195,18 @@ class PromptManager:
|
||||
def _generate_button_choice_instructions(
|
||||
channel: MessageChannel = None,
|
||||
) -> str:
|
||||
if channel and ChannelCapabilityManager.supports_buttons(
|
||||
if (
|
||||
channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(channel):
|
||||
and ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
):
|
||||
return (
|
||||
"- User questions: If you need the user to choose from a few clear options, "
|
||||
"call `ask_user_choice` to send button options. After the user clicks a button, "
|
||||
"the selected value will come back as the user's next message. After calling this tool, "
|
||||
"wait for the user's selection instead of repeating the question in plain text."
|
||||
)
|
||||
return (
|
||||
"- User questions: When you truly need user input, ask briefly in plain text."
|
||||
)
|
||||
return "- User questions: When you truly need user input, ask briefly in plain text."
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
|
||||
@@ -82,8 +82,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 非VERBOSE,重置缓冲区从头更新,保持消息编辑能力
|
||||
self._stream_handler.reset()
|
||||
# 非VERBOSE:工具边界至少补一个换行,避免工具前后的文本直接连在一起
|
||||
if self._stream_handler.last_buffer_char not in ("", "\n"):
|
||||
self._stream_handler.emit("\n")
|
||||
else:
|
||||
# 未启用流式传输,不发送任何工具消息内容
|
||||
pass
|
||||
|
||||
@@ -47,13 +47,13 @@ class AddDownloadTool(MoviePilotTool):
|
||||
if torrent_urls:
|
||||
if len(torrent_urls) == 1:
|
||||
if self._is_torrent_ref(torrent_urls[0]):
|
||||
message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
|
||||
message = f"添加下载任务: 资源 {torrent_urls[0]}"
|
||||
else:
|
||||
message = "正在添加下载任务: 磁力链接"
|
||||
message = "添加下载任务: 磁力链接"
|
||||
else:
|
||||
message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
|
||||
message = f"批量添加下载任务: 共 {len(torrent_urls)} 个资源"
|
||||
else:
|
||||
message = "正在添加下载任务"
|
||||
message = "添加下载任务"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
|
||||
|
||||
@@ -12,36 +12,74 @@ from app.schemas.types import MediaType
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[int] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None,
|
||||
description="Douban ID for precise media identification (optional, alternative to tmdb_id)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)")
|
||||
quality: Optional[str] = Field(None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
title: str = Field(
|
||||
...,
|
||||
description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')",
|
||||
)
|
||||
year: str = Field(
|
||||
...,
|
||||
description="Release year of the media (required for accurate identification)",
|
||||
)
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Season number for TV shows (optional). If omitted, the subscription defaults to season 1 only. "
|
||||
"To subscribe multiple seasons or the full series, call this tool separately for each season."
|
||||
),
|
||||
)
|
||||
tmdb_id: Optional[int] = Field(
|
||||
None,
|
||||
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)",
|
||||
)
|
||||
douban_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Douban ID for precise media identification (optional, alternative to tmdb_id)",
|
||||
)
|
||||
start_episode: Optional[int] = Field(
|
||||
None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)",
|
||||
)
|
||||
total_episode: Optional[int] = Field(
|
||||
None,
|
||||
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)",
|
||||
)
|
||||
quality: Optional[str] = Field(
|
||||
None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
|
||||
)
|
||||
effect: Optional[str] = Field(
|
||||
None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
|
||||
)
|
||||
filter_groups: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)",
|
||||
)
|
||||
sites: Optional[List[int]] = Field(
|
||||
None,
|
||||
description="List of site IDs to search from (optional, can be obtained from query_sites tool)",
|
||||
)
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria. Supports advanced filtering options like quality, resolution, and effect filters using regular expressions."
|
||||
description: str = (
|
||||
"Add media subscription to create automated download rules for movies and TV shows. "
|
||||
"The system will automatically search and download new episodes or releases based on the subscription criteria. "
|
||||
"For TV shows, omitting `season` subscribes season 1 only by default; to subscribe multiple seasons or "
|
||||
"the full series, call this tool once per season. Supports advanced filtering options like quality, "
|
||||
"resolution, and effect filters using regular expressions."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -50,52 +88,72 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
year = kwargs.get("year", "")
|
||||
media_type = kwargs.get("media_type", "")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在添加订阅: {title}"
|
||||
|
||||
message = f"添加订阅: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
|
||||
elif media_type == "tv":
|
||||
message += " 第1季(默认)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
title: str,
|
||||
year: str,
|
||||
media_type: str,
|
||||
season: Optional[int] = None,
|
||||
tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}"
|
||||
)
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
effective_season = (
|
||||
season
|
||||
if season is not None
|
||||
else 1
|
||||
if media_type_enum == MediaType.TV
|
||||
else None
|
||||
)
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
if start_episode is not None:
|
||||
subscribe_kwargs['start_episode'] = start_episode
|
||||
subscribe_kwargs["start_episode"] = start_episode
|
||||
if total_episode is not None:
|
||||
subscribe_kwargs['total_episode'] = total_episode
|
||||
subscribe_kwargs["total_episode"] = total_episode
|
||||
if quality:
|
||||
subscribe_kwargs['quality'] = quality
|
||||
subscribe_kwargs["quality"] = quality
|
||||
if resolution:
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
subscribe_kwargs["resolution"] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
subscribe_kwargs["effect"] = effect
|
||||
if filter_groups:
|
||||
subscribe_kwargs['filter_groups'] = filter_groups
|
||||
subscribe_kwargs["filter_groups"] = filter_groups
|
||||
if sites:
|
||||
subscribe_kwargs['sites'] = sites
|
||||
subscribe_kwargs["sites"] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=media_type_enum,
|
||||
@@ -105,13 +163,21 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
**subscribe_kwargs,
|
||||
)
|
||||
if sid:
|
||||
if message and "已存在" in message:
|
||||
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
|
||||
result_msg = f"订阅已存在:{title} ({year})"
|
||||
if effective_season is not None:
|
||||
result_msg += f" 第{effective_season}季"
|
||||
result_msg += "。如需修改参数请先删除旧订阅。"
|
||||
return result_msg
|
||||
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if effective_season is not None:
|
||||
result_msg += f" 第{effective_season}季"
|
||||
if season is None:
|
||||
result_msg += "(未指定季号,默认按第一季订阅)"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
if start_episode is not None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.interaction import (
|
||||
from app.chain.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
@@ -75,7 +75,7 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
message = kwargs.get("message", "") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送按钮选择: {message}"
|
||||
return f"发送按钮选择: {message}"
|
||||
|
||||
@staticmethod
|
||||
def _truncate_button_text(text: str, max_length: int) -> str:
|
||||
@@ -106,7 +106,7 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
):
|
||||
return f"当前渠道 {channel.value} 不支持按钮选择"
|
||||
|
||||
max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
|
||||
max_per_row = 1
|
||||
max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
|
||||
max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
|
||||
max_options = max_per_row * max_rows
|
||||
|
||||
@@ -108,16 +108,16 @@ class BrowseWebpageTool(MoviePilotTool):
|
||||
url = kwargs.get("url", "")
|
||||
selector = kwargs.get("selector", "")
|
||||
action_messages = {
|
||||
"goto": f"正在打开网页: {url}",
|
||||
"get_content": "正在获取页面内容",
|
||||
"screenshot": "正在截取页面截图",
|
||||
"click": f"正在点击元素: {selector}",
|
||||
"fill": f"正在填写表单: {selector}",
|
||||
"select": f"正在选择选项: {selector}",
|
||||
"evaluate": "正在执行 JavaScript",
|
||||
"wait": f"正在等待元素: {selector}",
|
||||
"goto": f"打开网页: {url}",
|
||||
"get_content": "获取页面内容",
|
||||
"screenshot": "截取页面截图",
|
||||
"click": f"点击元素: {selector}",
|
||||
"fill": f"填写表单: {selector}",
|
||||
"select": f"选择选项: {selector}",
|
||||
"evaluate": "执行 JavaScript",
|
||||
"wait": f"等待元素: {selector}",
|
||||
}
|
||||
return action_messages.get(action, f"正在执行浏览器操作: {action}")
|
||||
return action_messages.get(action, f"执行浏览器操作: {action}")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -41,7 +41,7 @@ class DeleteDownloadTool(MoviePilotTool):
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
message = f"删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
|
||||
@@ -30,7 +30,7 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除下载历史记录 ID: {history_id}"
|
||||
return f"删除下载历史记录 ID: {history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
@@ -34,7 +34,7 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
return f"正在删除订阅 (ID: {subscribe_id})"
|
||||
return f"删除订阅 (ID: {subscribe_id})"
|
||||
|
||||
async def run(self, subscribe_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
@@ -30,7 +30,7 @@ class DeleteTransferHistoryTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除整理历史记录: ID={history_id}"
|
||||
return f"删除整理历史记录: ID={history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
@@ -28,7 +28,7 @@ class EditFileTool(MoviePilotTool):
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在编辑文件: {file_name}"
|
||||
return f"编辑文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, old_text: str, new_text: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""执行Shell命令工具"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -9,6 +13,54 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 60
|
||||
MAX_TIMEOUT_SECONDS = 300
|
||||
MAX_OUTPUT_CHARS = 6000
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
|
||||
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CommandOutput:
|
||||
"""保存受限命令输出,避免大输出一次性进入内存。"""
|
||||
|
||||
limit: int
|
||||
stdout_chunks: list[str] = field(default_factory=list)
|
||||
stderr_chunks: list[str] = field(default_factory=list)
|
||||
captured_chars: int = 0
|
||||
truncated: bool = False
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
|
||||
remaining = self.limit - self.captured_chars
|
||||
if remaining <= 0:
|
||||
self.truncated = True
|
||||
return
|
||||
|
||||
captured = text[:remaining]
|
||||
if stream_name == "stdout":
|
||||
self.stdout_chunks.append(captured)
|
||||
else:
|
||||
self.stderr_chunks.append(captured)
|
||||
|
||||
self.captured_chars += len(captured)
|
||||
if len(text) > remaining:
|
||||
self.truncated = True
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
return "".join(self.stdout_chunks).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
return "".join(self.stderr_chunks).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
|
||||
@@ -23,14 +75,160 @@ class ExecuteCommandInput(BaseModel):
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and hard output limits."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行系统命令: {command}"
|
||||
return f"执行系统命令: {command}"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
|
||||
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout。"""
|
||||
try:
|
||||
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
|
||||
except (TypeError, ValueError):
|
||||
normalized = DEFAULT_TIMEOUT_SECONDS
|
||||
|
||||
if normalized <= 0:
|
||||
return DEFAULT_TIMEOUT_SECONDS, "timeout 参数无效,已使用默认 60 秒"
|
||||
if normalized > MAX_TIMEOUT_SECONDS:
|
||||
return (
|
||||
MAX_TIMEOUT_SECONDS,
|
||||
f"timeout 参数超过上限,已从 {normalized} 秒限制为 {MAX_TIMEOUT_SECONDS} 秒",
|
||||
)
|
||||
return normalized, None
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
"stderr": asyncio.subprocess.PIPE,
|
||||
}
|
||||
if os.name == "posix":
|
||||
kwargs["start_new_session"] = True
|
||||
elif os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def _read_stream(
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
limit_reached: asyncio.Event,
|
||||
) -> None:
|
||||
"""按块读取输出,达到上限后通知主流程终止命令。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
continue
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
|
||||
"""向进程组发送终止信号;不支持进程组的平台回退为单进程终止。"""
|
||||
try:
|
||||
if os.name == "posix":
|
||||
os.killpg(process.pid, sig)
|
||||
elif sig == getattr(signal, "SIGKILL", None):
|
||||
process.kill()
|
||||
else:
|
||||
process.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_process(
|
||||
cls,
|
||||
process: asyncio.subprocess.Process,
|
||||
wait_task: asyncio.Task,
|
||||
) -> None:
|
||||
"""先温和终止,失败后强杀,避免超时 shell 遗留子进程。"""
|
||||
if wait_task.done():
|
||||
return
|
||||
|
||||
cls._terminate_process(process, signal.SIGTERM)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
|
||||
)
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
kill_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
|
||||
cls._terminate_process(process, kill_signal)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("命令进程强制清理超时: pid=%s", process.pid)
|
||||
|
||||
@staticmethod
|
||||
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
|
||||
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
|
||||
if not reader_tasks:
|
||||
return
|
||||
done, pending = await asyncio.wait(reader_tasks, timeout=1)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
results = await asyncio.gather(*done, *pending, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception) and not isinstance(
|
||||
result, asyncio.CancelledError
|
||||
):
|
||||
logger.debug("命令输出读取任务异常: %s", result)
|
||||
|
||||
@staticmethod
|
||||
def _format_result(
|
||||
*,
|
||||
exit_code: Optional[int],
|
||||
output: _CommandOutput,
|
||||
timeout: int,
|
||||
timed_out: bool,
|
||||
output_limited: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
elif output_limited:
|
||||
result = (
|
||||
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
|
||||
f"已截断并终止进程,退出码: {exit_code})"
|
||||
)
|
||||
else:
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.stdout:
|
||||
result += f"\n\n标准输出:\n{output.stdout}"
|
||||
if output.stderr:
|
||||
result += f"\n\n错误输出:\n{output.stderr}"
|
||||
if output.truncated:
|
||||
result += "\n\n...(输出内容过长,已截断)"
|
||||
if not output.stdout and not output.stderr:
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(
|
||||
@@ -50,46 +248,57 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
)
|
||||
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
|
||||
limit_reached = asyncio.Event()
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
limit_task = asyncio.create_task(limit_reached.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stdout, "stdout", output, limit_reached
|
||||
)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stderr, "stderr", output, limit_reached
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
output_limited = False
|
||||
done, _ = await asyncio.wait(
|
||||
{wait_task, limit_task},
|
||||
timeout=normalized_timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
if wait_task not in done:
|
||||
if limit_task in done:
|
||||
output_limited = True
|
||||
else:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
limit_task.cancel()
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时处理
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return f"命令执行超时 (限制: {timeout}秒)"
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
output_limited=output_limited,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
|
||||
@@ -62,7 +62,7 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"douban_hot": "豆瓣热门",
|
||||
"douban_movie_hot": "豆瓣热门电影",
|
||||
"douban_tv_hot": "豆瓣热门电视剧",
|
||||
"douban_movie_showing": "豆瓣正在热映",
|
||||
"douban_movie_showing": "豆瓣热映",
|
||||
"douban_movies": "豆瓣最新电影",
|
||||
"douban_tvs": "豆瓣最新电视剧",
|
||||
"douban_movie_top250": "豆瓣电影TOP250",
|
||||
@@ -73,7 +73,7 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
}
|
||||
source_desc = source_map.get(source, source)
|
||||
|
||||
message = f"正在获取推荐: {source_desc}"
|
||||
message = f"获取推荐: {source_desc}"
|
||||
if media_type != "all":
|
||||
message += f" [{media_type}]"
|
||||
message += f" (第{page}页)"
|
||||
|
||||
@@ -53,7 +53,7 @@ class GetSearchResultsTool(MoviePilotTool):
|
||||
args_schema: Type[BaseModel] = GetSearchResultsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return "正在获取搜索结果"
|
||||
return "获取搜索结果"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -32,7 +32,7 @@ class ListDirectoryTool(MoviePilotTool):
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
|
||||
message = f"正在查询目录: {path}"
|
||||
message = f"查询目录: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class ListSlashCommandsTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询所有可用命令"
|
||||
return "查询所有可用命令"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -55,7 +55,7 @@ class ModifyDownloadTool(MoviePilotTool):
|
||||
tags = kwargs.get("tags")
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
parts = [f"正在修改下载任务: {hash_value}"]
|
||||
parts = [f"修改下载任务: {hash_value}"]
|
||||
if action == "start":
|
||||
parts.append("操作: 开始下载")
|
||||
elif action == "stop":
|
||||
|
||||
@@ -31,7 +31,7 @@ class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询自定义识别词"
|
||||
return "查询自定义识别词"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -32,7 +32,7 @@ class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
storage_type = kwargs.get("storage_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询目录配置"]
|
||||
parts = ["查询目录配置"]
|
||||
|
||||
if directory_type != "all":
|
||||
type_map = {"download": "下载目录", "library": "媒体库目录"}
|
||||
|
||||
@@ -36,7 +36,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
# 查询下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
@@ -71,7 +71,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
hash_value = kwargs.get("hash")
|
||||
title = kwargs.get("title")
|
||||
|
||||
parts = ["正在查询下载任务"]
|
||||
parts = ["查询下载任务"]
|
||||
|
||||
if downloader:
|
||||
parts.append(f"下载器: {downloader}")
|
||||
|
||||
@@ -23,7 +23,7 @@ class QueryDownloadersTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询下载器配置"
|
||||
return "查询下载器配置"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -29,7 +29,7 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
season = kwargs.get("season")
|
||||
episode_group = kwargs.get("episode_group")
|
||||
|
||||
message = f"正在查询剧集上映时间: TMDB ID {tmdb_id} 第{season}季"
|
||||
message = f"查询剧集上映时间: TMDB ID {tmdb_id} 第{season}季"
|
||||
if episode_group:
|
||||
message += f" (剧集组: {episode_group})"
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询已安装插件"
|
||||
return "查询已安装插件"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -93,11 +93,11 @@ class QueryLibraryExistsTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
|
||||
if tmdb_id:
|
||||
message = f"正在查询媒体库: TMDB={tmdb_id}"
|
||||
message = f"查询媒体库: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在查询媒体库: 豆瓣={douban_id}"
|
||||
message = f"查询媒体库: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在查询媒体库"
|
||||
message = "查询媒体库"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
return message
|
||||
|
||||
@@ -39,7 +39,7 @@ class QueryLibraryLatestTool(MoviePilotTool):
|
||||
server = kwargs.get("server")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
parts = ["查询媒体服务器最近入库影片"]
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
|
||||
@@ -29,8 +29,8 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
if tmdb_id:
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
return f"正在查询媒体详情: 豆瓣 ID {douban_id}"
|
||||
return f"查询媒体详情: TMDB ID {tmdb_id}"
|
||||
return f"查询媒体详情: 豆瓣 ID {douban_id}"
|
||||
|
||||
async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")
|
||||
|
||||
@@ -40,8 +40,8 @@ class QueryPluginCapabilitiesTool(MoviePilotTool):
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
if plugin_id:
|
||||
return f"正在查询插件 {plugin_id} 的能力"
|
||||
return "正在查询所有插件的能力"
|
||||
return f"查询插件 {plugin_id} 的能力"
|
||||
return "查询所有插件的能力"
|
||||
|
||||
async def run(self, plugin_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
@@ -39,7 +39,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = [f"正在查询热门订阅 [{media_type}]"]
|
||||
parts = [f"查询热门订阅 [{media_type}]"]
|
||||
|
||||
if min_sub:
|
||||
parts.append(f"最少订阅: {min_sub}")
|
||||
|
||||
@@ -22,7 +22,7 @@ class QueryRuleGroupsTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
return "正在查询所有规则组"
|
||||
return "查询所有规则组"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -22,7 +22,7 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询定时服务"
|
||||
return "查询定时服务"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
@@ -40,7 +40,7 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
message = f"查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
|
||||
@@ -37,7 +37,7 @@ class QuerySitesTool(MoviePilotTool):
|
||||
status = kwargs.get("status", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询站点"]
|
||||
parts = ["查询站点"]
|
||||
|
||||
if status != "all":
|
||||
status_map = {"active": "已启用", "inactive": "已禁用"}
|
||||
|
||||
@@ -44,7 +44,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name = kwargs.get("name")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅历史"]
|
||||
parts = ["查询订阅历史"]
|
||||
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
@@ -34,7 +34,7 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = ["正在查询订阅分享"]
|
||||
parts = ["查询订阅分享"]
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
@@ -79,7 +79,7 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅"]
|
||||
parts = ["查询订阅"]
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
|
||||
@@ -33,7 +33,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
status = kwargs.get("status", "all")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询整理历史"]
|
||||
parts = ["查询整理历史"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
|
||||
@@ -30,7 +30,7 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
name = kwargs.get("name")
|
||||
trigger_type = kwargs.get("trigger_type", "all")
|
||||
|
||||
parts = ["正在查询工作流"]
|
||||
parts = ["查询工作流"]
|
||||
|
||||
if state != "all":
|
||||
state_map = {"W": "等待", "R": "运行中", "P": "暂停", "S": "成功", "F": "失败"}
|
||||
|
||||
@@ -29,7 +29,7 @@ class ReadFileTool(MoviePilotTool):
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在读取文件: {file_name}"
|
||||
return f"读取文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, start_line: Optional[int] = None,
|
||||
end_line: Optional[int] = None, **kwargs) -> str:
|
||||
|
||||
@@ -33,13 +33,13 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
path = kwargs.get("path")
|
||||
|
||||
if path:
|
||||
message = f"正在识别文件媒体信息: {path}"
|
||||
message = f"识别文件媒体信息: {path}"
|
||||
elif title:
|
||||
message = f"正在识别种子媒体信息: {title}"
|
||||
message = f"识别种子媒体信息: {title}"
|
||||
if subtitle:
|
||||
message += f" ({subtitle})"
|
||||
else:
|
||||
message = "正在识别媒体信息"
|
||||
message = "识别媒体信息"
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
job_id = kwargs.get("job_id", "")
|
||||
return f"正在运行定时服务 (ID: {job_id})"
|
||||
return f"运行定时服务 (ID: {job_id})"
|
||||
|
||||
async def run(self, job_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: job_id={job_id}")
|
||||
|
||||
@@ -45,7 +45,7 @@ class RunSlashCommandTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行命令: {command}"
|
||||
return f"执行命令: {command}"
|
||||
|
||||
async def run(self, command: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}")
|
||||
|
||||
@@ -38,7 +38,7 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
message = f"执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
|
||||
@@ -47,7 +47,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
storage = kwargs.get("storage", "local")
|
||||
overwrite = kwargs.get("overwrite", False)
|
||||
|
||||
message = f"正在刮削媒体元数据: {path}"
|
||||
message = f"刮削媒体元数据: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
if overwrite:
|
||||
|
||||
@@ -34,7 +34,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在搜索媒体: {title}"
|
||||
message = f"搜索媒体: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
|
||||
@@ -24,7 +24,7 @@ class SearchPersonTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
name = kwargs.get("name", "")
|
||||
return f"正在搜索人物: {name}"
|
||||
return f"搜索人物: {name}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: name={name}")
|
||||
|
||||
@@ -29,7 +29,7 @@ class SearchPersonCreditsTool(MoviePilotTool):
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
person_id = kwargs.get("person_id", "")
|
||||
source = kwargs.get("source", "")
|
||||
return f"正在搜索人物参演作品: {source} ID {person_id}"
|
||||
return f"搜索人物参演作品: {source} ID {person_id}"
|
||||
|
||||
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")
|
||||
|
||||
@@ -32,7 +32,7 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
message = f"搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
|
||||
@@ -41,11 +41,11 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
|
||||
if tmdb_id:
|
||||
message = f"正在搜索种子: TMDB={tmdb_id}"
|
||||
message = f"搜索种子: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在搜索种子: 豆瓣={douban_id}"
|
||||
message = f"搜索种子: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在搜索种子"
|
||||
message = "搜索种子"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
return message
|
||||
|
||||
@@ -41,7 +41,7 @@ class SearchWebTool(MoviePilotTool):
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 20)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
return f"搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
|
||||
"""
|
||||
|
||||
@@ -55,7 +55,7 @@ class SendLocalFileTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在发送本地附件: {file_name}"
|
||||
return f"发送本地附件: {file_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -52,12 +52,12 @@ class SendMessageTool(MoviePilotTool):
|
||||
message = message[:50] + "..."
|
||||
|
||||
if title and image_url:
|
||||
return f"正在发送图文消息: [{title}] {message}"
|
||||
return f"发送图文消息: [{title}] {message}"
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
return f"发送消息: [{title}] {message}"
|
||||
if image_url:
|
||||
return f"正在发送图片消息: {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
return f"发送图片消息: {message}"
|
||||
return f"发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -41,7 +41,7 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
message = kwargs.get("message") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送语音回复: {message}"
|
||||
return f"发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
|
||||
@@ -24,7 +24,7 @@ class TestSiteTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据测试参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
return f"正在测试站点连通性: {site_identifier}"
|
||||
return f"测试站点连通性: {site_identifier}"
|
||||
|
||||
async def run(self, site_identifier: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
@@ -68,7 +68,7 @@ class TransferFileTool(MoviePilotTool):
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
message = f"整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
|
||||
@@ -57,7 +57,7 @@ class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
identifiers = kwargs.get("identifiers", [])
|
||||
return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"
|
||||
return f"更新自定义识别词(共 {len(identifiers)} 条规则)"
|
||||
|
||||
async def run(self, identifiers: List[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
|
||||
@@ -95,8 +95,8 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
return f"更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"更新站点 #{site_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -41,7 +41,7 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
message = f"更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
|
||||
@@ -117,8 +117,8 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
return f"更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ class WriteFileTool(MoviePilotTool):
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在写入文件: {file_name}"
|
||||
return f"写入文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, content: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
@@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen
|
||||
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
|
||||
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
|
||||
api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"])
|
||||
api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"])
|
||||
|
||||
158
app/api/endpoints/anthropic.py
Normal file
158
app/api/endpoints/anthropic.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Header, Security
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.api.endpoints.openai import (
|
||||
MODEL_ID,
|
||||
_CollectingMoviePilotAgent,
|
||||
_error_response as _openai_error_response,
|
||||
)
|
||||
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
|
||||
from app.core.config import settings
|
||||
from app.core.security import anthropic_api_key_header
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
SESSION_PREFIX = "anthropic:"
|
||||
|
||||
|
||||
def _anthropic_error_response(
|
||||
message: str,
|
||||
status_code: int,
|
||||
error_type: str = "invalid_request_error",
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=schemas.AnthropicErrorResponse(
|
||||
error=schemas.AnthropicErrorDetail(type=error_type, message=message)
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]:
|
||||
if not api_key or api_key != settings.API_TOKEN:
|
||||
return _anthropic_error_response(
|
||||
"invalid x-api-key",
|
||||
401,
|
||||
error_type="authentication_error",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _stream_anthropic_response(
|
||||
agent: _CollectingMoviePilotAgent,
|
||||
prompt: str,
|
||||
images: List[str],
|
||||
) -> AsyncIterator[str]:
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
if hasattr(agent.stream_handler, "bind_queue"):
|
||||
agent.stream_handler.bind_queue(event_queue)
|
||||
|
||||
message_id = f"msg_{uuid.uuid4().hex}"
|
||||
|
||||
async def _run_agent():
|
||||
try:
|
||||
await agent.process(prompt, images=images, files=None)
|
||||
except Exception as exc:
|
||||
await event_queue.put({"error": str(exc)})
|
||||
finally:
|
||||
await event_queue.put(None)
|
||||
|
||||
task = asyncio.create_task(_run_agent())
|
||||
try:
|
||||
yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n"
|
||||
while True:
|
||||
item = await event_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("error"):
|
||||
raise RuntimeError(str(item["error"]))
|
||||
text = str(item or "")
|
||||
if not text:
|
||||
continue
|
||||
yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
|
||||
async def messages(
|
||||
payload: schemas.AnthropicMessagesRequest,
|
||||
x_api_key: Optional[str] = Security(anthropic_api_key_header),
|
||||
anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"),
|
||||
):
|
||||
auth_error = _check_auth(x_api_key)
|
||||
if auth_error:
|
||||
return auth_error
|
||||
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return _anthropic_error_response(
|
||||
"MoviePilot AI agent is disabled.",
|
||||
503,
|
||||
error_type="api_error",
|
||||
)
|
||||
|
||||
normalized_messages = build_anthropic_messages(payload.system, payload.messages)
|
||||
try:
|
||||
prompt, images = build_prompt(normalized_messages, use_server_session=False)
|
||||
except ValueError as exc:
|
||||
return _anthropic_error_response(str(exc), 400)
|
||||
|
||||
session_seed = anthropic_version or "anthropic"
|
||||
session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX)
|
||||
agent = _CollectingMoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=session_id,
|
||||
channel=MessageChannel.Web.value,
|
||||
source="anthropic",
|
||||
username="anthropic-client",
|
||||
stream_mode=payload.stream,
|
||||
)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
_stream_anthropic_response(agent=agent, prompt=prompt, images=images),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.process(prompt, images=images, files=None)
|
||||
except Exception as exc:
|
||||
return _anthropic_error_response(str(exc), 500, error_type="api_error")
|
||||
|
||||
content = "\n\n".join(
|
||||
message.strip()
|
||||
for message in agent.collected_messages
|
||||
if message and message.strip()
|
||||
).strip()
|
||||
if not content and result:
|
||||
content = str(result).strip()
|
||||
if not content:
|
||||
content = "未获得有效回复。"
|
||||
|
||||
return schemas.AnthropicMessagesResponse(
|
||||
id=f"msg_{uuid.uuid4().hex}",
|
||||
content=[schemas.AnthropicTextBlock(text=content)],
|
||||
model=MODEL_ID,
|
||||
)
|
||||
426
app/api/endpoints/openai.py
Normal file
426
app/api/endpoints/openai.py
Normal file
@@ -0,0 +1,426 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Request, Security
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from app import schemas
|
||||
from app.api.openai_utils import (
|
||||
build_completion_payload,
|
||||
build_prompt,
|
||||
build_responses_input,
|
||||
build_session_id,
|
||||
)
|
||||
from app.agent import MoviePilotAgent, StreamingHandler
|
||||
from app.core.config import settings
|
||||
from app.core.security import openai_bearer_scheme
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
MODEL_ID = "moviepilot-agent"
|
||||
SESSION_PREFIX = "openai:"
|
||||
|
||||
|
||||
class _CollectingMoviePilotAgent(MoviePilotAgent):
|
||||
"""
|
||||
捕获 Agent 最终输出,避免再通过消息渠道二次发送。
|
||||
"""
|
||||
|
||||
def __init__(self, *args, stream_mode: bool = False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.collected_messages: List[str] = []
|
||||
self.stream_mode = stream_mode
|
||||
if stream_mode:
|
||||
self.stream_handler = _OpenAIStreamingHandler()
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
return self.stream_mode
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = ""):
|
||||
text = (message or "").strip()
|
||||
if title and text:
|
||||
text = f"{title}\n{text}"
|
||||
elif title:
|
||||
text = title.strip()
|
||||
if text:
|
||||
self.collected_messages.append(text)
|
||||
if self.stream_mode:
|
||||
self.stream_handler.emit(text)
|
||||
|
||||
async def _save_agent_message_to_db(self, message: str, title: str = ""):
|
||||
return None
|
||||
|
||||
|
||||
class _OpenAIStreamingHandler(StreamingHandler):
|
||||
"""
|
||||
将 Agent 流式输出转发到 OpenAI SSE 队列,不向站内消息系统落消息。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._event_queue: Optional[asyncio.Queue] = None
|
||||
|
||||
def bind_queue(self, queue: asyncio.Queue):
|
||||
self._event_queue = queue
|
||||
|
||||
def emit(self, token: str):
|
||||
super().emit(token)
|
||||
if token and self._event_queue is not None:
|
||||
self._event_queue.put_nowait(token)
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
channel: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
title: str = "",
|
||||
):
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._user_id = user_id
|
||||
self._username = username
|
||||
self._title = title
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._max_message_length = 0
|
||||
|
||||
async def stop_streaming(self) -> Tuple[bool, str]:
|
||||
if not self._streaming_enabled:
|
||||
return False, ""
|
||||
self._streaming_enabled = False
|
||||
with self._lock:
|
||||
final_text = self._buffer
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
return True, final_text
|
||||
|
||||
|
||||
def _sse_payload(data: dict) -> str:
|
||||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
async def _stream_response(
|
||||
agent: _CollectingMoviePilotAgent,
|
||||
prompt: str,
|
||||
images: List[str],
|
||||
) -> AsyncIterator[str]:
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
|
||||
agent.stream_handler.bind_queue(event_queue)
|
||||
|
||||
created = int(time.time())
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
finished = False
|
||||
|
||||
async def _run_agent():
|
||||
try:
|
||||
await agent.process(prompt, images=images, files=None)
|
||||
except Exception as exc:
|
||||
await event_queue.put({"error": str(exc)})
|
||||
finally:
|
||||
await event_queue.put(None)
|
||||
|
||||
task = asyncio.create_task(_run_agent())
|
||||
|
||||
try:
|
||||
yield _sse_payload(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": MODEL_ID,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
while True:
|
||||
item = await event_queue.get()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, dict) and item.get("error"):
|
||||
raise RuntimeError(str(item["error"]))
|
||||
text = str(item or "")
|
||||
if not text:
|
||||
continue
|
||||
yield _sse_payload(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": MODEL_ID,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": text},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
finished = True
|
||||
yield _sse_payload(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": MODEL_ID,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
elif finished:
|
||||
await task
|
||||
|
||||
|
||||
def _error_response(
|
||||
message: str,
|
||||
status_code: int,
|
||||
error_type: str = "invalid_request_error",
|
||||
code: Optional[str] = None,
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=schemas.OpenAIErrorResponse(
|
||||
error=schemas.OpenAIErrorDetail(
|
||||
message=message,
|
||||
type=error_type,
|
||||
code=code,
|
||||
)
|
||||
).model_dump(),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def _check_auth(
|
||||
credentials: Optional[HTTPAuthorizationCredentials],
|
||||
) -> Optional[JSONResponse]:
|
||||
if not credentials or credentials.scheme.lower() != "bearer":
|
||||
return _error_response(
|
||||
"Invalid bearer token.",
|
||||
401,
|
||||
error_type="authentication_error",
|
||||
code="invalid_api_key",
|
||||
)
|
||||
if credentials.credentials != settings.API_TOKEN:
|
||||
return _error_response(
|
||||
"Invalid bearer token.",
|
||||
401,
|
||||
error_type="authentication_error",
|
||||
code="invalid_api_key",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
|
||||
async def list_models(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
|
||||
):
|
||||
auth_error = _check_auth(credentials)
|
||||
if auth_error:
|
||||
return auth_error
|
||||
now = int(time.time())
|
||||
return schemas.OpenAIModelListResponse(
|
||||
data=[schemas.OpenAIModelInfo(id=MODEL_ID, created=now)]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/completions",
|
||||
summary="OpenAI compatible chat completions",
|
||||
response_model=schemas.OpenAIChatCompletionResponse,
|
||||
)
|
||||
async def chat_completions(
|
||||
payload: schemas.OpenAIChatCompletionsRequest,
|
||||
request: Request,
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
|
||||
):
|
||||
auth_error = _check_auth(credentials)
|
||||
if auth_error:
|
||||
return auth_error
|
||||
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return _error_response(
|
||||
"MoviePilot AI agent is disabled.",
|
||||
503,
|
||||
error_type="server_error",
|
||||
code="ai_agent_disabled",
|
||||
)
|
||||
|
||||
if not payload.messages:
|
||||
return _error_response(
|
||||
"`messages` must be a non-empty array.",
|
||||
400,
|
||||
code="invalid_messages",
|
||||
)
|
||||
|
||||
session_key = (
|
||||
str(payload.user or "").strip()
|
||||
or str(request.headers.get("x-session-id") or "").strip()
|
||||
or str(uuid.uuid4())
|
||||
)
|
||||
use_server_session = bool(
|
||||
str(payload.user or "").strip()
|
||||
or str(request.headers.get("x-session-id") or "").strip()
|
||||
)
|
||||
|
||||
try:
|
||||
prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
|
||||
except ValueError as exc:
|
||||
return _error_response(str(exc), 400, code="invalid_messages")
|
||||
|
||||
session_id = build_session_id(session_key, SESSION_PREFIX)
|
||||
username = str(payload.user or "openai-client")
|
||||
agent = _CollectingMoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=session_key,
|
||||
channel=MessageChannel.Web.value,
|
||||
source="openai",
|
||||
username=username,
|
||||
stream_mode=payload.stream,
|
||||
)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
_stream_response(agent=agent, prompt=prompt, images=images),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.process(prompt, images=images, files=None)
|
||||
except Exception as exc:
|
||||
return _error_response(
|
||||
str(exc),
|
||||
500,
|
||||
error_type="server_error",
|
||||
code="agent_execution_failed",
|
||||
)
|
||||
|
||||
content = "\n\n".join(
|
||||
message.strip()
|
||||
for message in agent.collected_messages
|
||||
if message and message.strip()
|
||||
).strip()
|
||||
if not content and result:
|
||||
content = str(result).strip()
|
||||
if not content:
|
||||
content = "未获得有效回复。"
|
||||
|
||||
return JSONResponse(content=build_completion_payload(content, MODEL_ID))
|
||||
|
||||
|
||||
@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
|
||||
async def responses(
|
||||
payload: schemas.OpenAIResponsesRequest,
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
|
||||
):
|
||||
auth_error = _check_auth(credentials)
|
||||
if auth_error:
|
||||
return auth_error
|
||||
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return _error_response(
|
||||
"MoviePilot AI agent is disabled.",
|
||||
503,
|
||||
error_type="server_error",
|
||||
code="ai_agent_disabled",
|
||||
)
|
||||
|
||||
if payload.stream:
|
||||
return _error_response(
|
||||
"Streaming is not supported for /responses yet.",
|
||||
400,
|
||||
code="unsupported_stream",
|
||||
)
|
||||
|
||||
normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
|
||||
if not normalized_messages:
|
||||
return _error_response(
|
||||
"`input` must include at least one usable message.",
|
||||
400,
|
||||
code="invalid_input",
|
||||
)
|
||||
|
||||
try:
|
||||
prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
|
||||
except ValueError as exc:
|
||||
return _error_response(str(exc), 400, code="invalid_input")
|
||||
|
||||
session_key = str(payload.user or uuid.uuid4())
|
||||
session_id = build_session_id(session_key, SESSION_PREFIX)
|
||||
agent = _CollectingMoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=session_key,
|
||||
channel=MessageChannel.Web.value,
|
||||
source="openai.responses",
|
||||
username=str(payload.user or "openai-client"),
|
||||
stream_mode=False,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.process(prompt, images=images, files=None)
|
||||
except Exception as exc:
|
||||
return _error_response(
|
||||
str(exc),
|
||||
500,
|
||||
error_type="server_error",
|
||||
code="agent_execution_failed",
|
||||
)
|
||||
|
||||
content = "\n\n".join(
|
||||
message.strip()
|
||||
for message in agent.collected_messages
|
||||
if message and message.strip()
|
||||
).strip()
|
||||
if not content and result:
|
||||
content = str(result).strip()
|
||||
if not content:
|
||||
content = "未获得有效回复。"
|
||||
|
||||
created_at = int(time.time())
|
||||
response_id = f"resp_{uuid.uuid4().hex}"
|
||||
output_message = schemas.OpenAIResponsesOutputMessage(
|
||||
id=f"msg_{uuid.uuid4().hex}",
|
||||
content=[schemas.OpenAIResponsesOutputText(text=content)],
|
||||
)
|
||||
return schemas.OpenAIResponsesResponse(
|
||||
id=response_id,
|
||||
created_at=created_at,
|
||||
model=MODEL_ID,
|
||||
output=[output_message],
|
||||
usage=schemas.OpenAIUsage(),
|
||||
)
|
||||
@@ -12,6 +12,7 @@ from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
@@ -29,14 +30,14 @@ from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.helper.image import ImageHelper
|
||||
from app.helper.llm import LLMHelper, LLMTestTimeout
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
@@ -45,7 +46,6 @@ from app.utils.crypto import HashUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from pydantic import BaseModel
|
||||
from version import APP_VERSION
|
||||
|
||||
router = APIRouter()
|
||||
@@ -57,7 +57,7 @@ class LlmTestRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
disable_thinking: Optional[bool] = None
|
||||
thinking_level: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
@@ -269,74 +269,6 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
return rules
|
||||
|
||||
|
||||
def _build_llm_test_data(
|
||||
duration_ms: Optional[int] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
构造 LLM 测试接口的基础返回数据。
|
||||
"""
|
||||
data = {
|
||||
"provider": provider if provider is not None else settings.LLM_PROVIDER,
|
||||
"model": model if model is not None else settings.LLM_MODEL,
|
||||
}
|
||||
if duration_ms is not None:
|
||||
data["duration_ms"] = duration_ms
|
||||
return data
|
||||
|
||||
|
||||
def _normalize_llm_test_value(
|
||||
value: Optional[str], *, empty_as_none: bool = False
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
清理来自前端的 LLM 测试字段。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if empty_as_none and not stripped:
|
||||
return None
|
||||
return stripped
|
||||
|
||||
|
||||
def _build_llm_test_snapshot(payload: Optional[LlmTestRequest] = None) -> dict[str, Any]:
|
||||
"""
|
||||
冻结当前 LLM 测试所需配置。
|
||||
|
||||
优先使用前端传入的临时参数;未传入时回退到已保存配置,兼容旧调用。
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER
|
||||
model = settings.LLM_MODEL
|
||||
disable_thinking = bool(getattr(settings, "LLM_DISABLE_THINKING", False))
|
||||
api_key = settings.LLM_API_KEY
|
||||
base_url = settings.LLM_BASE_URL
|
||||
enabled = bool(settings.AI_AGENT_ENABLE)
|
||||
|
||||
if payload:
|
||||
if payload.enabled is not None:
|
||||
enabled = bool(payload.enabled)
|
||||
if payload.provider is not None:
|
||||
provider = _normalize_llm_test_value(payload.provider) or ""
|
||||
if payload.model is not None:
|
||||
model = _normalize_llm_test_value(payload.model) or ""
|
||||
if payload.disable_thinking is not None:
|
||||
disable_thinking = bool(payload.disable_thinking)
|
||||
if payload.api_key is not None:
|
||||
api_key = _normalize_llm_test_value(payload.api_key, empty_as_none=True)
|
||||
if payload.base_url is not None:
|
||||
base_url = _normalize_llm_test_value(payload.base_url, empty_as_none=True)
|
||||
|
||||
return {
|
||||
"enabled": enabled,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"disable_thinking": disable_thinking,
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
清理错误信息中的敏感字段,避免回显密钥。
|
||||
@@ -428,12 +360,12 @@ async def _close_nettest_response(response: Any) -> None:
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None,
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None,
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
@@ -455,6 +387,7 @@ async def fetch_image(
|
||||
use_cache=use_cache,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
@@ -467,16 +400,17 @@ async def fetch_image(
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
|
||||
@@ -505,9 +439,9 @@ async def proxy_img(
|
||||
|
||||
@router.get("/cache/image", summary="图片缓存")
|
||||
async def cache_img(
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
@@ -601,7 +535,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
|
||||
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
|
||||
async def set_env_setting(
|
||||
env: dict, _: User = Depends(get_current_active_superuser_async)
|
||||
env: dict, _: User = Depends(get_current_active_superuser_async)
|
||||
):
|
||||
"""
|
||||
更新系统环境变量(仅管理员)
|
||||
@@ -636,9 +570,9 @@ async def set_env_setting(
|
||||
|
||||
@router.get("/progress/{process_type}", summary="实时进度")
|
||||
async def get_progress(
|
||||
request: Request,
|
||||
process_type: str,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
request: Request,
|
||||
process_type: str,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
@@ -673,9 +607,9 @@ async def get_setting(key: str, _: User = Depends(get_current_active_user_async)
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
async def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
更新系统设置(仅管理员)
|
||||
@@ -709,10 +643,10 @@ async def set_setting(
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
@@ -728,28 +662,33 @@ async def get_llm_models(
|
||||
|
||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(
|
||||
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
snapshot = _build_llm_test_snapshot(payload)
|
||||
data = _build_llm_test_data(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
)
|
||||
if not snapshot["enabled"]:
|
||||
if not payload:
|
||||
return schemas.Response(success=False, message="请配置智能助手LLM相关参数后再进行测试")
|
||||
|
||||
if not payload.provider or not payload.model:
|
||||
return schemas.Response(success=False, message="请配置LLM提供商和模型")
|
||||
|
||||
data = {
|
||||
"provider": payload.provider,
|
||||
"model": payload.model,
|
||||
}
|
||||
if not payload.enabled:
|
||||
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
||||
|
||||
if not snapshot["api_key"]:
|
||||
if not payload.api_key or not payload.api_key.strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM API Key",
|
||||
data=data,
|
||||
)
|
||||
|
||||
if not (snapshot["model"] or "").strip():
|
||||
if not payload.model or not payload.model.strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM 模型",
|
||||
@@ -758,50 +697,36 @@ async def llm_test(
|
||||
|
||||
try:
|
||||
result = await LLMHelper.test_current_settings(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
disable_thinking=snapshot["disable_thinking"],
|
||||
api_key=snapshot["api_key"],
|
||||
base_url=snapshot["base_url"],
|
||||
provider=payload.provider,
|
||||
model=payload.model,
|
||||
thinking_level=payload.thinking_level,
|
||||
api_key=payload.api_key,
|
||||
base_url=payload.base_url,
|
||||
)
|
||||
if not result.get("reply_preview"):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="模型响应为空",
|
||||
data=_build_llm_test_data(
|
||||
result.get("duration_ms"),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
message="模型响应为空"
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except (LLMTestTimeout, TimeoutError) as err:
|
||||
logger.warning(err)
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="LLM 调用超时",
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
message="LLM 调用超时"
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=_sanitize_llm_test_error(str(err), snapshot["api_key"]),
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
message=_sanitize_llm_test_error(str(err), payload.api_key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(
|
||||
request: Request,
|
||||
role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
request: Request,
|
||||
role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统消息,返回格式为SSE
|
||||
@@ -824,10 +749,10 @@ async def get_message(
|
||||
|
||||
@router.get("/logging", summary="实时日志")
|
||||
async def get_logging(
|
||||
request: Request,
|
||||
length: Optional[int] = 50,
|
||||
logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
request: Request,
|
||||
length: Optional[int] = 50,
|
||||
logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统日志
|
||||
@@ -838,7 +763,7 @@ async def get_logging(
|
||||
log_path = base_path / logfile
|
||||
|
||||
if not await SecurityUtils.async_is_safe_path(
|
||||
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
|
||||
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
@@ -855,7 +780,7 @@ async def get_logging(
|
||||
|
||||
# 读取历史日志
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 优化大文件读取策略
|
||||
if file_size > 100 * 1024:
|
||||
@@ -867,7 +792,7 @@ async def get_logging(
|
||||
# 找到第一个完整的行
|
||||
first_newline = content.find("\n")
|
||||
if first_newline != -1:
|
||||
content = content[first_newline + 1 :]
|
||||
content = content[first_newline + 1:]
|
||||
else:
|
||||
# 小文件直接读取全部内容
|
||||
content = await f.read()
|
||||
@@ -875,7 +800,7 @@ async def get_logging(
|
||||
# 按行分割并添加到队列,只保留非空行
|
||||
lines = [line.strip() for line in content.splitlines() if line.strip()]
|
||||
# 只取最后N行
|
||||
for line in lines[-max(length, 50) :]:
|
||||
for line in lines[-max(length, 50):]:
|
||||
lines_queue.append(line)
|
||||
|
||||
# 输出历史日志
|
||||
@@ -884,7 +809,7 @@ async def get_logging(
|
||||
|
||||
# 实时监听新日志
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 移动文件指针到文件末尾,继续监听新增内容
|
||||
await f.seek(0, 2)
|
||||
@@ -923,7 +848,7 @@ async def get_logging(
|
||||
try:
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
@@ -955,10 +880,10 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
|
||||
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
|
||||
def ruletest(
|
||||
title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
过滤规则测试,规则类型 1-订阅,2-洗版,3-搜索
|
||||
@@ -1013,11 +938,10 @@ async def nettest_targets(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
async def nettest(
|
||||
target_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
target_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试内置目标的网络连通性。
|
||||
|
||||
177
app/api/openai_utils.py
Normal file
177
app/api/openai_utils.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
||||
def _get_message_field(message: Any, field: str, default: Any = None) -> Any:
|
||||
if isinstance(message, dict):
|
||||
return message.get(field, default)
|
||||
return getattr(message, field, default)
|
||||
|
||||
|
||||
def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
|
||||
if content is None:
|
||||
return "", []
|
||||
if isinstance(content, str):
|
||||
return content.strip(), []
|
||||
|
||||
text_parts: List[str] = []
|
||||
image_urls: List[str] = []
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
normalized = item.strip()
|
||||
if normalized:
|
||||
text_parts.append(normalized)
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = (item.get("type") or "").lower()
|
||||
if item_type == "text":
|
||||
text = item.get("text")
|
||||
if text and str(text).strip():
|
||||
text_parts.append(str(text).strip())
|
||||
elif item_type == "input_text":
|
||||
text = item.get("text")
|
||||
if text and str(text).strip():
|
||||
text_parts.append(str(text).strip())
|
||||
elif item_type == "image_url":
|
||||
image_url = item.get("image_url")
|
||||
url = image_url.get("url") if isinstance(image_url, dict) else image_url
|
||||
if url and str(url).strip():
|
||||
image_urls.append(str(url).strip())
|
||||
elif item_type == "input_image":
|
||||
url = item.get("image_url")
|
||||
if url and str(url).strip():
|
||||
image_urls.append(str(url).strip())
|
||||
elif item_type == "image":
|
||||
source = item.get("source") or {}
|
||||
if isinstance(source, dict) and source.get("type") == "base64":
|
||||
data = source.get("data")
|
||||
media_type = source.get("media_type") or "image/png"
|
||||
if data and str(data).strip():
|
||||
image_urls.append(f"data:{media_type};base64,{str(data).strip()}")
|
||||
return "\n".join(text_parts).strip(), image_urls
|
||||
|
||||
|
||||
def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]:
|
||||
system_texts: List[str] = []
|
||||
transcript: List[str] = []
|
||||
latest_user_text = ""
|
||||
latest_user_images: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
role = str(_get_message_field(message, "role", "user") or "user").lower()
|
||||
if role == "developer":
|
||||
role = "system"
|
||||
text, images = extract_text_and_images(_get_message_field(message, "content"))
|
||||
if role == "system":
|
||||
if text:
|
||||
system_texts.append(text)
|
||||
continue
|
||||
if role == "user":
|
||||
if text or images:
|
||||
latest_user_text = text
|
||||
latest_user_images = images
|
||||
if text:
|
||||
transcript.append(f"user: {text}")
|
||||
continue
|
||||
if text:
|
||||
transcript.append(f"{role}: {text}")
|
||||
|
||||
if not latest_user_text and not latest_user_images:
|
||||
raise ValueError("No usable user message found in messages.")
|
||||
|
||||
prompt_parts: List[str] = []
|
||||
if system_texts:
|
||||
prompt_parts.append("系统要求:\n" + "\n\n".join(system_texts))
|
||||
|
||||
if not use_server_session and transcript:
|
||||
history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript
|
||||
if history:
|
||||
prompt_parts.append("对话上下文:\n" + "\n".join(history[-10:]))
|
||||
|
||||
if latest_user_text:
|
||||
prompt_parts.append("当前用户消息:\n" + latest_user_text)
|
||||
else:
|
||||
prompt_parts.append("当前用户消息:\n请结合图片内容回复。")
|
||||
|
||||
return "\n\n".join(part for part in prompt_parts if part).strip(), latest_user_images
|
||||
|
||||
|
||||
def build_session_id(session_key: str, prefix: str) -> str:
|
||||
digest = hashlib.sha256(session_key.encode("utf-8")).hexdigest()
|
||||
return f"{prefix}{digest[:32]}"
|
||||
|
||||
|
||||
def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]:
|
||||
created = int(time.time())
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex}",
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model_id,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def build_responses_input(
|
||||
input_data: Any, instructions: str | None = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if instructions and str(instructions).strip():
|
||||
messages.append({"role": "system", "content": str(instructions).strip()})
|
||||
|
||||
if isinstance(input_data, str):
|
||||
normalized = input_data.strip()
|
||||
if normalized:
|
||||
messages.append({"role": "user", "content": normalized})
|
||||
return messages
|
||||
|
||||
if isinstance(input_data, list):
|
||||
for item in input_data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = (item.get("type") or "").lower()
|
||||
if item_type == "message":
|
||||
role = item.get("role") or "user"
|
||||
content = item.get("content")
|
||||
messages.append({"role": role, "content": content})
|
||||
elif item.get("role") and "content" in item:
|
||||
messages.append({"role": item.get("role"), "content": item.get("content")})
|
||||
return messages
|
||||
|
||||
if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data:
|
||||
messages.append({"role": input_data.get("role"), "content": input_data.get("content")})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_anthropic_messages(
|
||||
system: Any, messages: List[Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
system_text, _ = extract_text_and_images(system)
|
||||
if system_text:
|
||||
normalized.append({"role": "system", "content": system_text})
|
||||
|
||||
for message in messages:
|
||||
role = _get_message_field(message, "role", "user")
|
||||
content = _get_message_field(message, "content")
|
||||
normalized.append({"role": role, "content": content})
|
||||
return normalized
|
||||
@@ -1407,6 +1407,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑已发送的消息
|
||||
@@ -1416,6 +1417,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:param buttons: 更新后的按钮列表
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
return self.run_module(
|
||||
@@ -1426,6 +1428,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
title=title,
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:
|
||||
|
||||
1363
app/chain/interaction.py
Normal file
1363
app/chain/interaction.py
Normal file
File diff suppressed because it is too large
Load Diff
1192
app/chain/message.py
1192
app/chain/message.py
File diff suppressed because it is too large
Load Diff
@@ -566,8 +566,8 @@ class SearchChain(ChainBase):
|
||||
) or []
|
||||
)
|
||||
search_count += 1
|
||||
# 有结果则停止
|
||||
if torrents:
|
||||
# 未开启多名称搜索时,有结果则停止
|
||||
if not settings.SEARCH_MULTIPLE_NAME and torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
@@ -654,7 +654,7 @@ class SearchChain(ChainBase):
|
||||
}
|
||||
|
||||
search_count += 1
|
||||
if torrents:
|
||||
if not settings.SEARCH_MULTIPLE_NAME and torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
|
||||
1241
app/chain/skills.py
Normal file
1241
app/chain/skills.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.chain.site import SiteChain
|
||||
from app.chain.skills import SkillsChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.chain.transfer import TransferChain
|
||||
@@ -154,6 +155,18 @@ class Command(metaclass=Singleton):
|
||||
"category": "管理",
|
||||
"data": {},
|
||||
},
|
||||
"/session_status": {
|
||||
"func": MessageChain().remote_session_status,
|
||||
"description": "会话状态",
|
||||
"category": "智能体",
|
||||
"data": {},
|
||||
},
|
||||
"/skills": {
|
||||
"func": SkillsChain().remote_manage,
|
||||
"description": "管理技能",
|
||||
"category": "智能体",
|
||||
"data": {},
|
||||
},
|
||||
}
|
||||
# 插件命令集合
|
||||
self._plugin_commands = {}
|
||||
|
||||
@@ -420,6 +420,15 @@ class ConfigModel(BaseModel):
|
||||
# 本地插件仓库目录,多个地址使用,分隔
|
||||
PLUGIN_LOCAL_REPO_PATHS: Optional[str] = None
|
||||
|
||||
# ==================== 技能配置 ====================
|
||||
# 技能市场仓库地址,多个地址使用,分隔
|
||||
SKILL_MARKET: str = (
|
||||
"https://clawhub.ai,"
|
||||
"https://github.com/openai/skills,"
|
||||
"https://github.com/anthropics/skills,"
|
||||
"https://github.com/vercel-labs/agent-skills"
|
||||
)
|
||||
|
||||
# ==================== Github & PIP ====================
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
GITHUB_TOKEN: Optional[str] = None
|
||||
@@ -496,8 +505,8 @@ class ConfigModel(BaseModel):
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# 是否尽量关闭模型的思考/推理能力(按各 provider/model 支持情况自动适配)
|
||||
LLM_DISABLE_THINKING: bool = True
|
||||
# 思考模式/深度配置:off/auto/minimal/low/medium/high/max/xhigh
|
||||
LLM_THINKING_LEVEL: Optional[str] = 'off'
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM API密钥
|
||||
|
||||
@@ -13,7 +13,7 @@ from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import HTTPException, status, Security, Request, Response
|
||||
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
|
||||
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app import schemas
|
||||
@@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
|
||||
# API KEY 通过 QUERY 认证
|
||||
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
|
||||
|
||||
# OpenAI compatible Bearer Token 认证
|
||||
openai_bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Anthropic compatible API Key 认证
|
||||
anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
|
||||
@@ -2,9 +2,13 @@
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -70,21 +74,120 @@ def _get_httpx_proxy_key() -> str:
|
||||
if "proxy" in params:
|
||||
return "proxy"
|
||||
return "proxies"
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"检测 httpx 代理参数失败,默认使用 'proxies':{e}")
|
||||
return "proxies"
|
||||
|
||||
|
||||
def _deepseek_thinking_toggle(extra_body: Any) -> bool | None:
|
||||
"""
|
||||
解析 DeepSeek extra_body 中显式传入的 thinking 开关。
|
||||
"""
|
||||
if not isinstance(extra_body, dict):
|
||||
return None
|
||||
|
||||
thinking = extra_body.get("thinking")
|
||||
if not isinstance(thinking, dict):
|
||||
return None
|
||||
|
||||
thinking_type = str(thinking.get("type") or "").strip().lower()
|
||||
if thinking_type == "enabled":
|
||||
return True
|
||||
if thinking_type == "disabled":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def _is_deepseek_thinking_enabled(model_name: str | None, extra_body: Any) -> bool:
|
||||
"""
|
||||
判断本次 DeepSeek 调用是否处于 thinking mode。
|
||||
"""
|
||||
explicit_toggle = _deepseek_thinking_toggle(extra_body)
|
||||
if explicit_toggle is not None:
|
||||
return explicit_toggle
|
||||
|
||||
normalized_model_name = str(model_name or "").strip().lower()
|
||||
if normalized_model_name == "deepseek-reasoner":
|
||||
return True
|
||||
if normalized_model_name.startswith("deepseek-v4-"):
|
||||
# DeepSeek V4 默认启用 thinking mode,除非显式关闭。
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _patch_deepseek_reasoning_content_support():
|
||||
"""
|
||||
修补 langchain-deepseek 在 tool-call 场景下遗漏 reasoning_content 回传的问题。
|
||||
|
||||
DeepSeek thinking mode 要求:若 assistant 历史消息包含 tool_calls,
|
||||
后续请求中必须带回该条消息的顶层 reasoning_content。
|
||||
某些 langchain-deepseek 版本虽然能从响应中拿到 reasoning_content,
|
||||
但不会在重放消息历史时写回请求载荷,导致 400。
|
||||
"""
|
||||
try:
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-deepseek reasoning_content 修补:{err}")
|
||||
return
|
||||
|
||||
if getattr(ChatDeepSeek, "_moviepilot_reasoning_content_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatDeepSeek, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-deepseek 缺少 _get_request_payload,无法修补 reasoning_content")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
# Resolve original messages so we can extract reasoning_content from
|
||||
# additional_kwargs. The parent's payload builder does not propagate
|
||||
# this DeepSeek-specific field.
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
|
||||
for i, message in enumerate(payload["messages"]):
|
||||
if message["role"] == "tool" and isinstance(message["content"], list):
|
||||
message["content"] = json.dumps(message["content"])
|
||||
elif message["role"] == "assistant":
|
||||
if isinstance(message["content"], list):
|
||||
# DeepSeek API expects assistant content to be a string,
|
||||
# not a list. Extract text blocks and join them, or use
|
||||
# empty string if none exist.
|
||||
text_parts = [
|
||||
block.get("text", "")
|
||||
for block in message["content"]
|
||||
if isinstance(block, dict) and block.get("type") == "text"
|
||||
]
|
||||
message["content"] = "".join(text_parts) if text_parts else ""
|
||||
|
||||
# DeepSeek reasoning models require every assistant message to
|
||||
# carry a reasoning_content field (even when empty). The value
|
||||
# is stored in AIMessage.additional_kwargs by
|
||||
# _create_chat_result(); re-inject it into the API payload.
|
||||
if (
|
||||
"reasoning_content" not in message
|
||||
and i < len(messages)
|
||||
and isinstance(messages[i], AIMessage)
|
||||
):
|
||||
message["reasoning_content"] = messages[i].additional_kwargs.get(
|
||||
"reasoning_content", ""
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
ChatDeepSeek._get_request_payload = _patched_get_request_payload
|
||||
ChatDeepSeek._moviepilot_reasoning_content_patched = True
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def _should_disable_thinking(disable_thinking: bool | None = None) -> bool:
|
||||
"""
|
||||
判断本次调用是否应尝试关闭模型思考能力。
|
||||
"""
|
||||
if disable_thinking is not None:
|
||||
return bool(disable_thinking)
|
||||
return bool(getattr(settings, "LLM_DISABLE_THINKING", False))
|
||||
_SUPPORTED_THINKING_LEVELS = frozenset(
|
||||
{"off", "auto", "minimal", "low", "medium", "high", "max", "xhigh"}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_name(model_name: str | None) -> str:
|
||||
@@ -94,48 +197,164 @@ class LLMHelper:
|
||||
return (model_name or "").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def _build_disabled_thinking_kwargs(
|
||||
cls,
|
||||
provider: str,
|
||||
model: str | None,
|
||||
disable_thinking: bool | None = None,
|
||||
def _normalize_deepseek_reasoning_effort(
|
||||
cls, thinking_level: str | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
DeepSeek 文档当前建议使用 high/max;兼容常见 effort 别名。
|
||||
"""
|
||||
if not thinking_level or thinking_level in {"off", "auto"}:
|
||||
return None
|
||||
|
||||
if thinking_level in {"minimal", "low", "medium", "high"}:
|
||||
return "high"
|
||||
if thinking_level in {"max", "xhigh"}:
|
||||
return "max"
|
||||
|
||||
logger.warning(f"忽略不支持的 DeepSeek reasoning_effort 配置: {thinking_level}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _normalize_openai_reasoning_effort(
|
||||
cls, thinking_level: str | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
OpenAI reasoning_effort 支持更细粒度的 effort,统一做最近似映射。
|
||||
"""
|
||||
if not thinking_level or thinking_level == "auto":
|
||||
return None
|
||||
if thinking_level == "off":
|
||||
return "none"
|
||||
if thinking_level == "max":
|
||||
return "xhigh"
|
||||
return thinking_level
|
||||
|
||||
@classmethod
|
||||
def _build_google_thinking_kwargs(
|
||||
cls, model_name: str, thinking_level: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
按 provider/model 生成“禁用思考”相关参数。
|
||||
|
||||
优先使用 LangChain/OpenAI SDK 已支持的原生字段;仅在 provider
|
||||
明确要求自定义请求体时,才回退到 extra_body。
|
||||
Gemini 3 使用 thinking_level;Gemini 2.5 使用 thinking_budget。
|
||||
"""
|
||||
if not cls._should_disable_thinking(disable_thinking):
|
||||
if not model_name or thinking_level == "auto":
|
||||
return {}
|
||||
|
||||
provider_name = (provider or "").strip().lower()
|
||||
model_name = cls._normalize_model_name(model)
|
||||
if not model_name:
|
||||
return {}
|
||||
|
||||
# Moonshot Kimi K2.5/K2.6 需要在请求体显式声明 thinking.disabled。
|
||||
if model_name.startswith(("kimi-k2.5", "kimi-k2.6")):
|
||||
return {"extra_body": {"thinking": {"type": "disabled"}}}
|
||||
|
||||
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
||||
if provider_name == "openai" and model_name.startswith(
|
||||
("gpt-5", "o1", "o3", "o4")
|
||||
):
|
||||
return {"reasoning_effort": "none"}
|
||||
|
||||
# Gemini 使用 google-genai / langchain-google-genai 内置思考控制参数。
|
||||
if provider_name == "google":
|
||||
if "gemini-2.5" in model_name:
|
||||
if "gemini-2.5" in model_name:
|
||||
if thinking_level == "off":
|
||||
if "pro" in model_name:
|
||||
# Gemini 2.5 Pro 官方不支持完全关闭思考,回退到最小预算。
|
||||
return {
|
||||
"thinking_budget": 128,
|
||||
"include_thoughts": False,
|
||||
}
|
||||
return {
|
||||
"thinking_budget": 0,
|
||||
"include_thoughts": False,
|
||||
}
|
||||
if "gemini-3" in model_name:
|
||||
return {
|
||||
"thinking_level": "minimal",
|
||||
|
||||
budget_map = {
|
||||
"minimal": 512,
|
||||
"low": 1024,
|
||||
"medium": 4096,
|
||||
"high": 8192,
|
||||
"max": 24576,
|
||||
"xhigh": 24576,
|
||||
}
|
||||
budget = budget_map.get(thinking_level)
|
||||
return (
|
||||
{
|
||||
"thinking_budget": budget,
|
||||
"include_thoughts": False,
|
||||
}
|
||||
if budget is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
if "gemini-3" in model_name:
|
||||
level_map = {
|
||||
"off": "minimal",
|
||||
"minimal": "minimal",
|
||||
"low": "low",
|
||||
"medium": "medium",
|
||||
"high": "high",
|
||||
"max": "high",
|
||||
"xhigh": "high",
|
||||
}
|
||||
google_level = level_map.get(thinking_level)
|
||||
return (
|
||||
{
|
||||
"thinking_level": google_level,
|
||||
"include_thoughts": False,
|
||||
}
|
||||
if google_level
|
||||
else {}
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _build_kimi_thinking_kwargs(
|
||||
cls, model_name: str, thinking_level: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Kimi 当前公开文档仅支持思考开关,不支持显式深度调节。
|
||||
"""
|
||||
if model_name.startswith("kimi-k2-thinking"):
|
||||
return {}
|
||||
if thinking_level == "off":
|
||||
return {"extra_body": {"thinking": {"type": "disabled"}}}
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _build_thinking_kwargs(
|
||||
cls,
|
||||
provider: str,
|
||||
model: str | None,
|
||||
thinking_level: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
按 provider/model 生成思考模式相关参数。
|
||||
|
||||
优先使用 LangChain/OpenAI SDK 已支持的原生字段;仅在 provider
|
||||
明确要求自定义请求体时,才回退到 extra_body。
|
||||
"""
|
||||
provider_name = (provider or "").strip().lower()
|
||||
model_name = cls._normalize_model_name(model)
|
||||
|
||||
if provider_name == "deepseek":
|
||||
if thinking_level == "off":
|
||||
return {"extra_body": {"thinking": {"type": "disabled"}}}
|
||||
if thinking_level == "auto":
|
||||
return {}
|
||||
|
||||
kwargs: dict[str, Any] = {"extra_body": {"thinking": {"type": "enabled"}}}
|
||||
deepseek_effort = cls._normalize_deepseek_reasoning_effort(
|
||||
thinking_level
|
||||
)
|
||||
if deepseek_effort:
|
||||
kwargs["reasoning_effort"] = deepseek_effort
|
||||
return kwargs
|
||||
|
||||
if model_name.startswith(("kimi-k2.5", "kimi-k2.6", "kimi-k2-thinking")):
|
||||
return cls._build_kimi_thinking_kwargs(model_name, thinking_level)
|
||||
|
||||
if not model_name:
|
||||
return {}
|
||||
|
||||
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
||||
if provider_name == "openai" and model_name.startswith(
|
||||
("gpt-5", "o1", "o3", "o4")
|
||||
):
|
||||
openai_effort = cls._normalize_openai_reasoning_effort(
|
||||
thinking_level
|
||||
)
|
||||
return {"reasoning_effort": openai_effort} if openai_effort else {}
|
||||
|
||||
# Gemini 使用 google-genai / langchain-google-genai 内置思考控制参数。
|
||||
if provider_name == "google":
|
||||
return cls._build_google_thinking_kwargs(
|
||||
model_name, thinking_level
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
@@ -148,16 +367,26 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_llm(
|
||||
streaming: bool = False,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
disable_thinking: bool | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
streaming: bool = False,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
thinking_level: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param provider: LLM提供商,默认为配置项LLM_PROVIDER
|
||||
:param model: 模型名称,默认为配置项LLM_MODEL
|
||||
:param thinking_level: 思考模式级别,默认为 None(即自动判断
|
||||
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。
|
||||
不同模型对思考模式的支持和表现不同,具体映射关系请
|
||||
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
|
||||
:param api_key: API Key,默认为
|
||||
配置项LLM_API_KEY。对于某些提供商(
|
||||
如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param base_url: API Base URL,默认为配置项LLM_BASE_URL。
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider_name = str(
|
||||
@@ -166,10 +395,10 @@ class LLMHelper:
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
thinking_kwargs = LLMHelper._build_disabled_thinking_kwargs(
|
||||
thinking_kwargs = LLMHelper._build_thinking_kwargs(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
disable_thinking=disable_thinking,
|
||||
thinking_level=thinking_level
|
||||
)
|
||||
|
||||
if not api_key_value:
|
||||
@@ -201,9 +430,11 @@ class LLMHelper:
|
||||
elif provider_name == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
_patch_deepseek_reasoning_content_support()
|
||||
model = ChatDeepSeek(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_base=base_url_value,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
@@ -231,7 +462,7 @@ class LLMHelper:
|
||||
else:
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
* 1000, # 转换为token单位
|
||||
* 1000, # 转换为token单位
|
||||
}
|
||||
|
||||
return model
|
||||
@@ -255,10 +486,10 @@ class LLMHelper:
|
||||
if isinstance(block, dict) or hasattr(block, "get"):
|
||||
block_type = block.get("type")
|
||||
if block.get("thought") or block_type in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block_type == "text":
|
||||
@@ -278,13 +509,13 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
async def test_current_settings(
|
||||
prompt: str = "请只回复 OK",
|
||||
timeout: int = 20,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
disable_thinking: bool | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
prompt: str = "请只回复 OK",
|
||||
timeout: int = 20,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
thinking_level: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
@@ -298,7 +529,7 @@ class LLMHelper:
|
||||
streaming=False,
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
disable_thinking=disable_thinking,
|
||||
thinking_level=thinking_level,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
)
|
||||
@@ -326,7 +557,7 @@ class LLMHelper:
|
||||
return data
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
@@ -364,7 +595,7 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
|
||||
1175
app/helper/skill.py
Normal file
1175
app/helper/skill.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -439,6 +439,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑消息
|
||||
@@ -448,6 +449,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:param buttons: 新的按钮列表
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
@@ -460,6 +462,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
result = client.send_msg(
|
||||
title=title or "",
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
original_message_id=message_id,
|
||||
original_chat_id=str(chat_id),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import re
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from lxml import etree
|
||||
@@ -11,6 +13,121 @@ from app.utils.string import StringUtils
|
||||
class NexusAudiencesSiteUserInfo(NexusPhpSiteUserInfo):
|
||||
schema = SiteSchema.NexusAudiences
|
||||
|
||||
def _parse_user_traffic_info(self, html_text):
|
||||
"""
|
||||
解析用户流量信息
|
||||
"""
|
||||
super()._parse_user_traffic_info(html_text)
|
||||
self.__parse_userbar_info(html_text)
|
||||
|
||||
def _parse_user_detail_info(self, html_text: str):
|
||||
"""
|
||||
解析用户额外信息
|
||||
"""
|
||||
super()._parse_user_detail_info(html_text)
|
||||
self.__parse_userbar_info(html_text)
|
||||
|
||||
def __parse_userbar_info(self, html_text: str):
|
||||
"""
|
||||
解析 Audiences 新版顶部用户栏,覆盖 NexusPHP 通用正则的误判。
|
||||
"""
|
||||
html = etree.HTML(html_text)
|
||||
try:
|
||||
if not StringUtils.is_valid_html_element(html):
|
||||
return
|
||||
|
||||
for user_node in html.xpath('//*[@data-uploader-url or @data-uploader-stats]'):
|
||||
self.__parse_user_identity(user_node)
|
||||
self.__parse_uploader_stats(user_node.get("data-uploader-stats"))
|
||||
|
||||
# data-uploader-stats 不包含分享率,需从 compact metric 的 class 中读取。
|
||||
self.__parse_compact_metric(html, "ratio", "ratio")
|
||||
self.__parse_compact_metric(html, "uploaded", "upload")
|
||||
self.__parse_compact_metric(html, "downloaded", "download")
|
||||
self.__parse_compact_metric(html, "bonus", "bonus")
|
||||
self.__parse_compact_metric(html, "active", "active")
|
||||
finally:
|
||||
if html is not None:
|
||||
del html
|
||||
|
||||
def __parse_user_identity(self, user_node):
|
||||
"""
|
||||
从新版用户卡属性中提取用户 ID、用户名和等级。
|
||||
"""
|
||||
user_url = user_node.get("data-uploader-url") or ""
|
||||
user_detail = re.search(r"userdetails\.php\?id=(\d+)", user_url)
|
||||
if user_detail and user_detail.group(1).strip():
|
||||
self.userid = user_detail.group(1).strip()
|
||||
|
||||
username = user_node.get("data-uploader-label")
|
||||
if username and username.strip():
|
||||
self.username = username.strip()
|
||||
|
||||
user_level = user_node.get("data-uploader-badge")
|
||||
if user_level and user_level.strip():
|
||||
self.user_level = user_level.strip()
|
||||
|
||||
def __parse_uploader_stats(self, stats_text: str):
|
||||
"""
|
||||
解析 data-uploader-stats 中的结构化流量数据。
|
||||
"""
|
||||
if not stats_text:
|
||||
return
|
||||
|
||||
try:
|
||||
stats = json.loads(stats_text)
|
||||
except (TypeError, ValueError):
|
||||
return
|
||||
|
||||
if not isinstance(stats, list):
|
||||
return
|
||||
|
||||
for item in stats:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
label = str(item.get("label") or "").strip(" ::")
|
||||
tone = str(item.get("tone") or "").strip()
|
||||
value = str(item.get("value") or "").strip()
|
||||
self.__set_metric_value(label=label, tone=tone, value=value)
|
||||
|
||||
def __parse_compact_metric(self, html, metric: str, field: str):
|
||||
"""
|
||||
按 compact metric 的 class 读取新版用户栏中的单项数据。
|
||||
"""
|
||||
values = html.xpath(
|
||||
f'//*[contains(concat(" ", normalize-space(@class), " "), " site-userbar__compact-metric--{metric} ")]'
|
||||
'//span[normalize-space()][last()]/text()'
|
||||
)
|
||||
if not values:
|
||||
values = html.xpath(
|
||||
f'//*[contains(concat(" ", normalize-space(@class), " "), " site-userbar__compact-metric--{metric} ")]'
|
||||
'/text()'
|
||||
)
|
||||
if values:
|
||||
self.__set_metric_value(field=field, value=values[-1].strip())
|
||||
|
||||
def __set_metric_value(self, value: str, label: str = None, tone: str = None, field: str = None):
|
||||
"""
|
||||
将 Audiences 用户栏指标写入通用用户数据字段。
|
||||
"""
|
||||
if not value:
|
||||
return
|
||||
|
||||
metric_key = field or tone or label
|
||||
if metric_key in {"uploaded", "上传量", "upload"}:
|
||||
self.upload = StringUtils.num_filesize(value)
|
||||
elif metric_key in {"downloaded", "下载量", "download"}:
|
||||
self.download = StringUtils.num_filesize(value)
|
||||
elif metric_key in {"bonus", "爆米花"}:
|
||||
self.bonus = StringUtils.str_float(value)
|
||||
elif metric_key == "ratio":
|
||||
self.ratio = StringUtils.str_float(value)
|
||||
elif metric_key in {"active", "活跃"}:
|
||||
active_match = re.search(r"↑\s*(\d+)\s*/\s*↓\s*(\d+)", value)
|
||||
if active_match:
|
||||
self.seeding = StringUtils.str_int(active_match.group(1))
|
||||
self.leeching = StringUtils.str_int(active_match.group(2))
|
||||
|
||||
def _parse_seeding_pages(self):
|
||||
if not self._torrent_seeding_page:
|
||||
return
|
||||
|
||||
@@ -557,6 +557,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑消息
|
||||
@@ -566,6 +567,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:param buttons: 新的按钮列表
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
@@ -578,6 +580,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
result = client.send_msg(
|
||||
title=title or "",
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
original_message_id=str(message_id),
|
||||
original_chat_id=str(chat_id),
|
||||
)
|
||||
|
||||
@@ -564,6 +564,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑消息
|
||||
@@ -573,6 +574,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:param buttons: 新的按钮列表
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
@@ -587,6 +589,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
message_id=message_id,
|
||||
text=text,
|
||||
title=title,
|
||||
buttons=buttons,
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
|
||||
@@ -835,6 +835,7 @@ class Telegram:
|
||||
message_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑Telegram消息(公开方法)
|
||||
@@ -842,6 +843,7 @@ class Telegram:
|
||||
:param message_id: 消息ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:param buttons: 新的按钮列表
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if not self._bot:
|
||||
@@ -861,6 +863,7 @@ class Telegram:
|
||||
chat_id=str(chat_id),
|
||||
message_id=int(message_id),
|
||||
text=caption,
|
||||
buttons=buttons,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"编辑Telegram消息异常: {str(e)}")
|
||||
|
||||
@@ -11,6 +11,7 @@ from .monitoring import *
|
||||
from .plugin import *
|
||||
from .response import *
|
||||
from .rule import *
|
||||
from .openai import *
|
||||
from .servarr import *
|
||||
from .servcookie import *
|
||||
from .site import *
|
||||
@@ -23,4 +24,3 @@ from .transfer import *
|
||||
from .user import *
|
||||
from .workflow import *
|
||||
from .mcp import *
|
||||
|
||||
|
||||
156
app/schemas/openai.py
Normal file
156
app/schemas/openai.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class OpenAIModelInfo(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int
|
||||
owned_by: str = "moviepilot"
|
||||
|
||||
|
||||
class OpenAIModelListResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[OpenAIModelInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Any
|
||||
name: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class OpenAIChatCompletionsRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: List[OpenAIChatMessage]
|
||||
user: Optional[str] = None
|
||||
stream: bool = False
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class OpenAIResponsesRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
input: Any
|
||||
instructions: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
stream: bool = False
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class OpenAIChatChoiceMessage(BaseModel):
|
||||
role: str = "assistant"
|
||||
content: str
|
||||
|
||||
|
||||
class OpenAIChatChoice(BaseModel):
|
||||
index: int = 0
|
||||
message: OpenAIChatChoiceMessage
|
||||
finish_reason: str = "stop"
|
||||
|
||||
|
||||
class OpenAIUsage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class OpenAIChatCompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[OpenAIChatChoice]
|
||||
usage: OpenAIUsage
|
||||
|
||||
|
||||
class OpenAIResponsesOutputText(BaseModel):
|
||||
type: str = "output_text"
|
||||
text: str
|
||||
annotations: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIResponsesOutputMessage(BaseModel):
|
||||
id: str
|
||||
type: str = "message"
|
||||
status: str = "completed"
|
||||
role: str = "assistant"
|
||||
content: List[OpenAIResponsesOutputText] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OpenAIResponsesResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "response"
|
||||
created_at: int
|
||||
status: str = "completed"
|
||||
model: str
|
||||
output: List[OpenAIResponsesOutputMessage] = Field(default_factory=list)
|
||||
error: Optional[Any] = None
|
||||
incomplete_details: Optional[Any] = None
|
||||
usage: OpenAIUsage
|
||||
|
||||
|
||||
class OpenAIErrorDetail(BaseModel):
|
||||
message: str
|
||||
type: str = "invalid_request_error"
|
||||
param: Optional[str] = None
|
||||
code: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIErrorResponse(BaseModel):
|
||||
error: OpenAIErrorDetail
|
||||
|
||||
|
||||
OpenAIChatContentPart = Dict[str, Any]
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
role: str
|
||||
content: Any
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: List[AnthropicMessage]
|
||||
system: Optional[Any] = None
|
||||
max_tokens: Optional[int] = 1024
|
||||
stream: bool = False
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AnthropicTextBlock(BaseModel):
|
||||
type: str = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class AnthropicUsage(BaseModel):
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
|
||||
class AnthropicMessagesResponse(BaseModel):
|
||||
id: str
|
||||
type: str = "message"
|
||||
role: str = "assistant"
|
||||
content: List[AnthropicTextBlock] = Field(default_factory=list)
|
||||
model: str
|
||||
stop_reason: str = "end_turn"
|
||||
stop_sequence: Optional[str] = None
|
||||
usage: AnthropicUsage = Field(default_factory=AnthropicUsage)
|
||||
|
||||
|
||||
class AnthropicErrorDetail(BaseModel):
|
||||
type: str = "invalid_request_error"
|
||||
message: str
|
||||
|
||||
|
||||
class AnthropicErrorResponse(BaseModel):
|
||||
type: str = "error"
|
||||
error: AnthropicErrorDetail
|
||||
@@ -76,14 +76,14 @@ pympler~=1.1
|
||||
smbprotocol~=1.15.0
|
||||
setproctitle~=1.3.6
|
||||
httpx[socks]~=0.28.1
|
||||
langchain~=1.2.13
|
||||
langchain-core~=1.2.20
|
||||
langchain~=1.2.15
|
||||
langchain-core~=1.3.2
|
||||
langchain-community~=0.4.1
|
||||
langchain-openai~=1.1.11
|
||||
langchain-google-genai~=4.2.1
|
||||
langchain-openai~=1.2.1
|
||||
langchain-google-genai~=4.2.2
|
||||
langchain-deepseek~=1.0.1
|
||||
langgraph~=1.1.3
|
||||
openai~=2.29.0
|
||||
google-genai~=1.68.0
|
||||
langgraph~=1.1.9
|
||||
openai~=2.32.0
|
||||
google-genai~=1.73.1
|
||||
ddgs~=9.10.0
|
||||
websocket-client~=1.8.0
|
||||
|
||||
@@ -1063,6 +1063,32 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
|
||||
print("请输入列表中的可选值。")
|
||||
|
||||
|
||||
def _env_llm_thinking_level_default() -> str:
|
||||
value = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
|
||||
alias_map = {
|
||||
"none": "off",
|
||||
"disabled": "off",
|
||||
"disable": "off",
|
||||
"enabled": "auto",
|
||||
"enable": "auto",
|
||||
"default": "auto",
|
||||
"dynamic": "auto",
|
||||
}
|
||||
normalized = alias_map.get(value, value)
|
||||
if normalized in {
|
||||
"off",
|
||||
"auto",
|
||||
"minimal",
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"max",
|
||||
"xhigh",
|
||||
}:
|
||||
return normalized
|
||||
return "auto"
|
||||
|
||||
|
||||
def _prompt_path(label: str, *, default: Path, allow_empty: bool = False) -> str:
|
||||
value = _prompt_text(label, default=str(default), allow_empty=allow_empty)
|
||||
if not value:
|
||||
@@ -1476,9 +1502,19 @@ def _collect_agent_config() -> dict[str, Any]:
|
||||
current_value=read_env_value("LLM_API_KEY"),
|
||||
required=True,
|
||||
),
|
||||
"LLM_DISABLE_THINKING": _prompt_yes_no(
|
||||
"是否尽量关闭模型思考/推理",
|
||||
default=_env_bool("LLM_DISABLE_THINKING", False),
|
||||
"LLM_THINKING_LEVEL": _prompt_choice(
|
||||
"LLM 思考模式/深度",
|
||||
choices={
|
||||
"off": "关闭思考",
|
||||
"auto": "自动",
|
||||
"minimal": "最小",
|
||||
"low": "低",
|
||||
"medium": "中",
|
||||
"high": "高",
|
||||
"max": "极高",
|
||||
"xhigh": "超高",
|
||||
},
|
||||
default=_env_llm_thinking_level_default(),
|
||||
),
|
||||
"LLM_SUPPORT_IMAGE_INPUT": _prompt_yes_no(
|
||||
"是否启用图片输入支持",
|
||||
@@ -1506,7 +1542,7 @@ def _load_auth_site_definitions_inner() -> dict[str, Any]:
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
|
||||
auth_sites = SitesHelper().get_authsites() or {}
|
||||
definitions: dict[str, Any] = {}
|
||||
@@ -1843,7 +1879,7 @@ def _apply_local_system_config_inner(config_payload: dict[str, Any]) -> None:
|
||||
):
|
||||
system_config.set(SystemConfigKey.UserSiteAuthParams, site_auth_item)
|
||||
try:
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
|
||||
status, msg = SitesHelper().check_user(
|
||||
site_auth_item.get("site"), site_auth_item.get("params")
|
||||
|
||||
164
skills/create-moviepilot-skill/SKILL.md
Normal file
164
skills/create-moviepilot-skill/SKILL.md
Normal file
@@ -0,0 +1,164 @@
|
||||
---
|
||||
name: create-moviepilot-skill
|
||||
version: 1
|
||||
description: >-
|
||||
Use this skill when the user asks to create, scaffold, update, or review a
|
||||
MoviePilot agent skill. This includes adding a new built-in skill under the
|
||||
repository `skills/` directory, editing an existing built-in skill, writing
|
||||
`SKILL.md` frontmatter and workflow instructions, choosing `allowed-tools`,
|
||||
adding helper scripts when needed, and bumping the built-in skill `version`
|
||||
so changes can sync into `config/agent/skills`.
|
||||
allowed-tools: list_directory read_file write_file edit_file execute_command
|
||||
---
|
||||
|
||||
# Create MoviePilot Skill
|
||||
|
||||
This skill guides you through creating or updating a built-in MoviePilot agent
|
||||
skill in this repository.
|
||||
|
||||
## Scope
|
||||
|
||||
Use this workflow for repository built-in skills:
|
||||
|
||||
- Create or update files under `skills/<skill-id>/`
|
||||
- Commit the skill as part of the MoviePilot repository
|
||||
- Do not place the implementation only in `config/agent/skills` unless the user
|
||||
explicitly asks for a local override instead of a built-in skill
|
||||
|
||||
## MoviePilot-Specific Rules
|
||||
|
||||
- The repository root `skills/` directory is the bundled source of truth for
|
||||
built-in skills.
|
||||
- On agent startup, bundled skills are synced into `config/agent/skills`.
|
||||
- Sync overwrite depends on the `version` field in `SKILL.md`. If you update an
|
||||
existing built-in skill, increment `version`, or users may continue using an
|
||||
older copied version.
|
||||
- Keep the folder name and frontmatter `name` identical. Use lowercase letters,
|
||||
digits, and hyphens only.
|
||||
- Prefer extending an existing skill instead of creating an overlapping
|
||||
duplicate.
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Understand the Request
|
||||
|
||||
- Determine whether the user wants a new skill or a change to an existing one.
|
||||
- Extract the target task, likely trigger phrases, needed tools, and whether
|
||||
helper scripts are necessary.
|
||||
- If the goal is still ambiguous after reading the request and local context,
|
||||
ask one focused clarification question. Otherwise proceed with a reasonable
|
||||
default.
|
||||
|
||||
### Step 2: Check Existing Skills First
|
||||
|
||||
- Inspect the repository `skills/` directory before creating anything new.
|
||||
- If an existing skill already covers most of the workflow, update it instead of
|
||||
adding a near-duplicate.
|
||||
- Reuse the repository style: concise YAML frontmatter, trigger-rich
|
||||
description, and procedural body sections.
|
||||
|
||||
### Step 3: Choose the Skill ID and Path
|
||||
|
||||
- New built-in skill path: `skills/<skill-id>/SKILL.md`
|
||||
- Keep `<skill-id>` short, hyphen-case, and under 64 characters.
|
||||
- Use a verb-led or domain-led name that makes the trigger obvious, such as
|
||||
`transfer-failed-retry`, `moviepilot-api`, or `create-moviepilot-skill`.
|
||||
|
||||
### Step 4: Write Frontmatter Correctly
|
||||
|
||||
Use this shape:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: create-moviepilot-skill
|
||||
version: 1
|
||||
description: >-
|
||||
Explain what the skill does and exactly when to use it.
|
||||
allowed-tools: list_directory read_file write_file edit_file execute_command
|
||||
---
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- `description` is the primary trigger surface. Put concrete "when to use"
|
||||
scenarios there.
|
||||
- Include `version` for built-in skills. Increment it whenever you ship a new
|
||||
built-in revision.
|
||||
- Add `allowed-tools` when the workflow depends on a small, well-defined tool
|
||||
set.
|
||||
- Add `compatibility` only when environment constraints actually matter.
|
||||
|
||||
### Step 5: Write the Body
|
||||
|
||||
The body should contain:
|
||||
|
||||
- A short purpose statement
|
||||
- MoviePilot-specific rules or guardrails
|
||||
- A step-by-step workflow
|
||||
- Concrete examples of matching user requests
|
||||
- References to supporting files when they exist
|
||||
|
||||
Prefer:
|
||||
|
||||
- Imperative instructions
|
||||
- Concrete file paths
|
||||
- Examples aligned with actual MoviePilot conventions
|
||||
|
||||
Avoid:
|
||||
|
||||
- Generic theory that does not change execution
|
||||
- Large duplicated documentation
|
||||
- Extra files like `README.md` or `CHANGELOG.md` inside the skill directory
|
||||
|
||||
### Step 6: Add Supporting Files Only When They Help
|
||||
|
||||
- Add `scripts/` only when the same deterministic work would otherwise be
|
||||
rewritten repeatedly.
|
||||
- Keep helper files inside the same skill directory.
|
||||
- Reference helper paths explicitly from `SKILL.md`.
|
||||
- If the skill is instructions-only, keep it to a single `SKILL.md`.
|
||||
|
||||
### Step 7: Implement the Skill
|
||||
|
||||
For a new built-in skill:
|
||||
|
||||
1. Create `skills/<skill-id>/`
|
||||
2. Create `SKILL.md`
|
||||
3. Add helper scripts only if they are justified
|
||||
|
||||
For an existing built-in skill:
|
||||
|
||||
1. Edit `skills/<skill-id>/SKILL.md`
|
||||
2. Increment `version`
|
||||
3. Update helper files in the same directory if needed
|
||||
|
||||
### Step 8: Validate Before Finishing
|
||||
|
||||
- Re-read the frontmatter and confirm `name` matches the directory name.
|
||||
- Confirm `description` mentions real trigger scenarios.
|
||||
- If you changed an existing built-in skill, confirm `version` increased.
|
||||
- If possible, validate the file can be parsed by the MoviePilot skills loader.
|
||||
- Report the final path and note whether the agent needs a restart to sync the
|
||||
latest built-in skill into `config/agent/skills`.
|
||||
|
||||
## Minimal Example
|
||||
|
||||
User request:
|
||||
|
||||
`给 MoviePilot agent 加一个处理站点 Cookie 更新的内置技能`
|
||||
|
||||
Expected outcome:
|
||||
|
||||
- Create or update a directory such as `skills/update-site-cookie/`
|
||||
- Write `SKILL.md` with a trigger-rich `description`
|
||||
- Include only the tools needed for that workflow
|
||||
- Increment `version` when revising an existing built-in skill
|
||||
|
||||
## Final Checklist
|
||||
|
||||
- Is the skill under the repository `skills/` directory?
|
||||
- Does the folder name equal frontmatter `name`?
|
||||
- Does `description` clearly say when the skill should trigger?
|
||||
- Did you avoid duplicating an existing skill unnecessarily?
|
||||
- Did you increment `version` for built-in skill updates?
|
||||
- Did you keep the skill lean and procedural?
|
||||
29
tests/test_agent_add_subscribe_tool.py
Normal file
29
tests/test_agent_add_subscribe_tool.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
|
||||
|
||||
class TestAgentAddSubscribeTool(unittest.TestCase):
|
||||
def test_tv_subscription_without_season_reports_default_first_season(self):
|
||||
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "")),
|
||||
):
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
title="Breaking Bad",
|
||||
year="2008",
|
||||
media_type="tv",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertIn("第1季", result)
|
||||
self.assertIn("默认按第一季订阅", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.agent.interaction import (
|
||||
from app.chain.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
|
||||
62
tests/test_agent_prompt_style.py
Normal file
62
tests/test_agent_prompt_style.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.agent.middleware.memory import MEMORY_ONBOARDING_PROMPT
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestAgentPromptStyle(unittest.TestCase):
|
||||
def test_agent_prompt_enforces_concise_professional_style(self):
|
||||
prompt = prompt_manager.get_agent_prompt()
|
||||
|
||||
self.assertIn("professional, concise, restrained", prompt)
|
||||
self.assertIn("Do NOT flatter the user", prompt)
|
||||
self.assertIn("NO praise, emotional cushioning", prompt)
|
||||
|
||||
def test_agent_prompt_defines_tv_subscription_default_season_rule(self):
|
||||
prompt = prompt_manager.get_agent_prompt()
|
||||
|
||||
self.assertIn(
|
||||
"omitting `season` means subscribe to season 1 only",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"call `add_subscribe` separately for each season",
|
||||
prompt,
|
||||
)
|
||||
|
||||
def test_non_verbose_prompt_requires_silence_until_all_tools_finish(self):
|
||||
with patch.object(settings, "AI_AGENT_VERBOSE", False):
|
||||
prompt = prompt_manager.get_agent_prompt()
|
||||
|
||||
self.assertIn(
|
||||
"[Important Instruction] STRICTLY ENFORCED:",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"DO NOT output any conversational text, explanations, progress updates, or acknowledgements before the first tool call or between tool calls",
|
||||
prompt,
|
||||
)
|
||||
self.assertIn(
|
||||
"Only then may you send one final user-facing reply",
|
||||
prompt,
|
||||
)
|
||||
|
||||
def test_verbose_prompt_does_not_inject_silence_until_tools_finish_rule(self):
|
||||
with patch.object(settings, "AI_AGENT_VERBOSE", True):
|
||||
prompt = prompt_manager.get_agent_prompt()
|
||||
|
||||
self.assertNotIn(
|
||||
"DO NOT output any conversational text, explanations, progress updates, or acknowledgements before the first tool call or between tool calls",
|
||||
prompt,
|
||||
)
|
||||
|
||||
def test_memory_onboarding_does_not_force_warm_intro(self):
|
||||
self.assertIn("Do NOT interrupt the current task", MEMORY_ONBOARDING_PROMPT)
|
||||
self.assertIn("Do NOT proactively greet warmly", MEMORY_ONBOARDING_PROMPT)
|
||||
self.assertNotIn("greet the user warmly", MEMORY_ONBOARDING_PROMPT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
106
tests/test_agent_session_status.py
Normal file
106
tests/test_agent_session_status.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentSessionStatus(unittest.TestCase):
|
||||
def test_usage_middleware_records_usage_metadata(self):
|
||||
snapshots = []
|
||||
middleware = UsageMiddleware(on_usage=snapshots.append)
|
||||
request = ModelRequest(
|
||||
model=SimpleNamespace(
|
||||
model="gpt-4o-mini", profile={"max_input_tokens": 128000}
|
||||
),
|
||||
messages=[],
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
response = ModelResponse(
|
||||
result=[
|
||||
AIMessage(
|
||||
content="ok",
|
||||
usage_metadata={
|
||||
"input_tokens": 1200,
|
||||
"output_tokens": 300,
|
||||
"total_tokens": 1500,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def handler(_: ModelRequest):
|
||||
return response
|
||||
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertIs(result, response)
|
||||
self.assertEqual(len(snapshots), 1)
|
||||
self.assertEqual(snapshots[0]["model"], "gpt-4o-mini")
|
||||
self.assertEqual(snapshots[0]["context_window_tokens"], 128000)
|
||||
self.assertEqual(snapshots[0]["input_tokens"], 1200)
|
||||
self.assertEqual(snapshots[0]["output_tokens"], 300)
|
||||
self.assertEqual(snapshots[0]["total_tokens"], 1500)
|
||||
self.assertAlmostEqual(snapshots[0]["context_usage_ratio"], 1200 / 128000)
|
||||
|
||||
def test_remote_session_status_sends_usage_summary(self):
|
||||
chain = MessageChain()
|
||||
chain._user_sessions["10001"] = ("session-1", datetime.now())
|
||||
status = {
|
||||
"session_id": "session-1",
|
||||
"model": "gpt-4o-mini",
|
||||
"context_window_tokens": 128000,
|
||||
"last_input_tokens": 1200,
|
||||
"last_output_tokens": 300,
|
||||
"last_total_tokens": 1500,
|
||||
"last_context_usage_ratio": 1200 / 128000,
|
||||
"total_input_tokens": 4500,
|
||||
"total_output_tokens": 1500,
|
||||
"total_tokens": 6000,
|
||||
"model_call_count": 4,
|
||||
"last_updated_at": "2026-04-26 12:34:56",
|
||||
"is_processing": True,
|
||||
"pending_messages": 2,
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.chain.message.agent_manager.get_session_status",
|
||||
return_value=status,
|
||||
),
|
||||
patch.object(chain, "post_message") as post_message,
|
||||
):
|
||||
chain.remote_session_status(
|
||||
channel=MessageChannel.Telegram,
|
||||
userid="10001",
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
notification = post_message.call_args.args[0]
|
||||
self.assertEqual(notification.title, "当前智能体会话状态")
|
||||
self.assertIn("session-1", notification.text)
|
||||
self.assertIn("gpt-4o-mini", notification.text)
|
||||
self.assertIn("1,200 / 128,000 (0.94%)", notification.text)
|
||||
self.assertIn("输入 4,500 / 输出 1,500 / 总计 6,000", notification.text)
|
||||
self.assertIn("运行中", notification.text)
|
||||
|
||||
def test_remote_session_status_handles_missing_session(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
chain.remote_session_status(
|
||||
channel=MessageChannel.Telegram,
|
||||
userid="10001",
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
notification = post_message.call_args.args[0]
|
||||
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")
|
||||
120
tests/test_agent_summarization_streaming.py
Normal file
120
tests/test_agent_summarization_streaming.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
|
||||
import app.agent as agent_module
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
_llm_type = "openai-chat"
|
||||
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
self.profile = {"max_input_tokens": 64000}
|
||||
|
||||
|
||||
class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
):
|
||||
agent._create_agent(streaming=True)
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
|
||||
self.assertIs(captured["model"], main_llm)
|
||||
self.assertIs(summary_middleware.model, non_streaming_llm)
|
||||
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
def __init__(self, model, max_tools):
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"LLMToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
|
||||
):
|
||||
agent._create_agent(streaming=True)
|
||||
|
||||
tool_selector_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, _FakeToolSelectorMiddleware)
|
||||
)
|
||||
|
||||
self.assertIs(tool_selector_middleware.model, non_streaming_llm)
|
||||
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
with (
|
||||
patch.object(agent, "_initialize_llm", return_value=main_llm),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
):
|
||||
agent._create_agent(streaming=False)
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
|
||||
self.assertIs(captured["model"], main_llm)
|
||||
self.assertIs(summary_middleware.model, main_llm)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
113
tests/test_agent_tool_streaming.py
Normal file
113
tests/test_agent_tool_streaming.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import langchain.agents as langchain_agents
|
||||
|
||||
if not hasattr(langchain_agents, "create_agent"):
|
||||
langchain_agents.create_agent = lambda *args, **kwargs: None
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.schemas.message import MessageResponse
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class DummyTool(MoviePilotTool):
|
||||
name: str = "dummy_tool"
|
||||
description: str = "Dummy tool for streaming tests."
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
class TestAgentToolStreaming(unittest.TestCase):
|
||||
async def _run_tool(self, initial_buffer: str) -> tuple[str, str]:
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
if initial_buffer:
|
||||
handler.emit(initial_buffer)
|
||||
tool.set_stream_handler(handler)
|
||||
|
||||
with patch.object(settings, "AI_AGENT_VERBOSE", False):
|
||||
result = await tool._arun(explanation="run test tool")
|
||||
|
||||
buffered_message = await handler.take()
|
||||
return result, buffered_message
|
||||
|
||||
def test_non_verbose_tool_call_appends_newline_separator(self):
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n")
|
||||
|
||||
def test_non_verbose_tool_call_does_not_duplicate_newline(self):
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n")
|
||||
|
||||
def test_non_verbose_tool_call_keeps_empty_buffer_unchanged(self):
|
||||
result, buffered_message = asyncio.run(self._run_tool(""))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "")
|
||||
|
||||
def test_flush_sends_direct_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._source = "telegram"
|
||||
handler._user_id = "10001"
|
||||
handler._username = "tester"
|
||||
handler._streaming_enabled = True
|
||||
handler.emit("hello")
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
|
||||
) as run_in_threadpool_mock:
|
||||
run_in_threadpool_mock.return_value = MessageResponse(
|
||||
message_id=1,
|
||||
chat_id=2,
|
||||
source="telegram",
|
||||
success=True,
|
||||
)
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
|
||||
)
|
||||
self.assertTrue(handler.has_sent_message)
|
||||
|
||||
def test_flush_edits_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._streaming_enabled = True
|
||||
handler._message_response = MessageResponse(
|
||||
message_id=1,
|
||||
chat_id=2,
|
||||
source="telegram",
|
||||
success=True,
|
||||
)
|
||||
handler._sent_text = "hello"
|
||||
handler.emit("hello world")
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
|
||||
) as run_in_threadpool_mock:
|
||||
run_in_threadpool_mock.return_value = True
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
|
||||
)
|
||||
self.assertEqual(handler._sent_text, "hello world")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
61
tests/test_execute_command_tool.py
Normal file
61
tests/test_execute_command_tool.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from app.agent.tools.impl.execute_command import (
|
||||
ExecuteCommandTool,
|
||||
MAX_OUTPUT_CHARS,
|
||||
)
|
||||
|
||||
|
||||
def _python_command(code: str) -> str:
|
||||
"""生成当前解释器可执行的 shell 命令,避免依赖系统 python 名称。"""
|
||||
args = [sys.executable, "-c", code]
|
||||
if os.name == "nt":
|
||||
return subprocess.list2cmdline(args)
|
||||
return " ".join(shlex.quote(arg) for arg in args)
|
||||
|
||||
|
||||
class TestExecuteCommandTool(unittest.TestCase):
|
||||
def _run_command(self, command: str, timeout: int = 60) -> str:
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
return asyncio.run(tool.run(command=command, timeout=timeout))
|
||||
|
||||
def test_large_output_is_truncated_before_returning_to_agent(self):
|
||||
command = _python_command(
|
||||
"import sys; sys.stdout.write('x' * 200000); sys.stdout.flush()"
|
||||
)
|
||||
|
||||
result = self._run_command(command)
|
||||
|
||||
self.assertIn("输出内容过长,已截断", result)
|
||||
self.assertLess(len(result), MAX_OUTPUT_CHARS + 500)
|
||||
|
||||
def test_timeout_returns_partial_output_promptly(self):
|
||||
command = _python_command(
|
||||
"import time; print('started', flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
started_at = time.monotonic()
|
||||
result = self._run_command(command, timeout=1)
|
||||
duration = time.monotonic() - started_at
|
||||
|
||||
self.assertLess(duration, 4)
|
||||
self.assertIn("命令执行超时", result)
|
||||
self.assertIn("started", result)
|
||||
|
||||
def test_timeout_is_capped(self):
|
||||
command = _python_command("print('ok')")
|
||||
|
||||
result = self._run_command(command, timeout=9999)
|
||||
|
||||
self.assertIn("timeout 参数超过上限", result)
|
||||
self.assertIn("ok", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
144
tests/test_langchain_deepseek_compat.py
Normal file
144
tests/test_langchain_deepseek_compat.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def _build_tool_call(name: str = "search", arguments: str = "{}"):
|
||||
return [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "tool_call",
|
||||
"name": name,
|
||||
"args": {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, model_name: str, model_kwargs: dict | None = None):
|
||||
self.model_name = model_name
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
messages = []
|
||||
for message in input_:
|
||||
payload_message = {
|
||||
"role": message.type,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.type == "human":
|
||||
payload_message["role"] = "user"
|
||||
elif message.type == "ai":
|
||||
payload_message["role"] = "assistant"
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload_message["tool_calls"] = tool_calls
|
||||
elif message.type == "tool":
|
||||
payload_message["role"] = "tool"
|
||||
payload_message["tool_call_id"] = message.tool_call_id
|
||||
messages.append(payload_message)
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload
|
||||
|
||||
|
||||
sys.modules.pop("app.helper.llm", None)
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
settings=ModuleType("settings"),
|
||||
)
|
||||
sys.modules["app.core.config"].settings.LLM_PROVIDER = "deepseek"
|
||||
sys.modules["app.core.config"].settings.LLM_MODEL = "deepseek-v4-pro"
|
||||
sys.modules["app.core.config"].settings.LLM_API_KEY = "sk-test"
|
||||
sys.modules["app.core.config"].settings.LLM_BASE_URL = "https://api.deepseek.com"
|
||||
sys.modules["app.core.config"].settings.LLM_THINKING_LEVEL = None
|
||||
sys.modules["app.core.config"].settings.LLM_TEMPERATURE = 0.1
|
||||
sys.modules["app.core.config"].settings.LLM_MAX_CONTEXT_TOKENS = 64
|
||||
sys.modules["app.core.config"].settings.PROXY_HOST = None
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path)
|
||||
llm_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(llm_module)
|
||||
|
||||
|
||||
class DeepSeekCompatPatchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_FakeChatDeepSeek._get_request_payload = _ORIGINAL_GET_REQUEST_PAYLOAD
|
||||
if hasattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched"):
|
||||
delattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched")
|
||||
llm_module._patch_deepseek_reasoning_content_support()
|
||||
|
||||
def test_injects_reasoning_content_for_assistant_tool_calls(self):
|
||||
llm = _FakeChatDeepSeek("deepseek-v4-pro")
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先调用天气工具"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertEqual(
|
||||
payload["messages"][1]["reasoning_content"],
|
||||
"先调用天气工具",
|
||||
)
|
||||
|
||||
def test_falls_back_to_empty_reasoning_content_when_missing(self):
|
||||
llm = _FakeChatDeepSeek("deepseek-v4-flash")
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(content="", tool_calls=_build_tool_call()),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertIn("reasoning_content", payload["messages"][1])
|
||||
self.assertEqual(payload["messages"][1]["reasoning_content"], "")
|
||||
|
||||
def test_skips_injection_when_thinking_is_disabled(self):
|
||||
llm = _FakeChatDeepSeek(
|
||||
"deepseek-v4-pro",
|
||||
model_kwargs={"extra_body": {"thinking": {"type": "disabled"}}},
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先调用天气工具"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertNotIn("reasoning_content", payload["messages"][1])
|
||||
@@ -38,7 +38,7 @@ _stub_module(
|
||||
LLM_MODEL="global-model",
|
||||
LLM_API_KEY="global-key",
|
||||
LLM_BASE_URL="https://global.example.com",
|
||||
LLM_DISABLE_THINKING=False,
|
||||
LLM_THINKING_LEVEL=None,
|
||||
LLM_TEMPERATURE=0.1,
|
||||
LLM_MAX_CONTEXT_TOKENS=64,
|
||||
PROXY_HOST=None,
|
||||
@@ -83,7 +83,9 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
streaming=False,
|
||||
provider="deepseek",
|
||||
model="deepseek-chat",
|
||||
thinking_level=None,
|
||||
disable_thinking=None,
|
||||
reasoning_effort=None,
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
@@ -138,7 +140,77 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
|
||||
def test_get_llm_uses_openai_reasoning_effort_none(self):
|
||||
def test_get_llm_uses_deepseek_thinking_level_controls(self):
|
||||
calls = []
|
||||
patch_calls = []
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, **kwargs):
|
||||
calls.append(kwargs)
|
||||
self.model = kwargs["model"]
|
||||
self.profile = None
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
|
||||
), patch.object(
|
||||
llm_module,
|
||||
"_patch_deepseek_reasoning_content_support",
|
||||
side_effect=lambda: patch_calls.append(True),
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="deepseek",
|
||||
model="deepseek-v4-pro",
|
||||
thinking_level="xhigh",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(
|
||||
calls[0].get("extra_body"),
|
||||
{"thinking": {"type": "enabled"}},
|
||||
)
|
||||
self.assertEqual(patch_calls, [True])
|
||||
self.assertEqual(calls[0].get("reasoning_effort"), "max")
|
||||
self.assertEqual(calls[0].get("api_base"), "https://api.deepseek.com")
|
||||
|
||||
def test_get_llm_disables_deepseek_thinking_via_thinking_level(self):
|
||||
calls = []
|
||||
patch_calls = []
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, **kwargs):
|
||||
calls.append(kwargs)
|
||||
self.model = kwargs["model"]
|
||||
self.profile = None
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
|
||||
), patch.object(
|
||||
llm_module,
|
||||
"_patch_deepseek_reasoning_content_support",
|
||||
side_effect=lambda: patch_calls.append(True),
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="deepseek",
|
||||
model="deepseek-v4-flash",
|
||||
thinking_level="off",
|
||||
api_key="sk-test",
|
||||
base_url="https://proxy.example.com",
|
||||
)
|
||||
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(
|
||||
calls[0].get("extra_body"),
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
self.assertEqual(patch_calls, [True])
|
||||
self.assertIsNone(calls[0].get("reasoning_effort"))
|
||||
self.assertEqual(calls[0].get("api_base"), "https://proxy.example.com")
|
||||
|
||||
def test_get_llm_uses_openai_reasoning_effort_none_for_off(self):
|
||||
calls = []
|
||||
|
||||
class _FakeChatOpenAI:
|
||||
@@ -154,7 +226,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="openai",
|
||||
model="gpt-5-mini",
|
||||
disable_thinking=True,
|
||||
thinking_level="off",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
@@ -162,6 +234,30 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].get("reasoning_effort"), "none")
|
||||
|
||||
def test_get_llm_maps_unified_max_to_openai_xhigh(self):
|
||||
calls = []
|
||||
|
||||
class _FakeChatOpenAI:
|
||||
def __init__(self, **kwargs):
|
||||
calls.append(kwargs)
|
||||
self.model = kwargs["model"]
|
||||
self.profile = None
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="openai",
|
||||
model="gpt-5.4",
|
||||
thinking_level="max",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].get("reasoning_effort"), "xhigh")
|
||||
|
||||
def test_get_llm_uses_gemini_builtin_thinking_controls(self):
|
||||
calls = []
|
||||
|
||||
@@ -182,7 +278,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="google",
|
||||
model="gemini-2.5-flash",
|
||||
disable_thinking=True,
|
||||
thinking_level="off",
|
||||
api_key="sk-test",
|
||||
base_url=None,
|
||||
)
|
||||
@@ -191,6 +287,35 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
self.assertEqual(calls[0].get("thinking_budget"), 0)
|
||||
self.assertFalse(calls[0].get("include_thoughts"))
|
||||
|
||||
def test_get_llm_uses_gemini_3_thinking_level_controls(self):
|
||||
calls = []
|
||||
|
||||
class _FakeChatGoogleGenerativeAI:
|
||||
def __init__(self, **kwargs):
|
||||
calls.append(kwargs)
|
||||
self.model = kwargs["model"]
|
||||
self.profile = None
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
"langchain_google_genai": SimpleNamespace(
|
||||
ChatGoogleGenerativeAI=_FakeChatGoogleGenerativeAI
|
||||
)
|
||||
},
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="google",
|
||||
model="gemini-3.1-flash",
|
||||
thinking_level="xhigh",
|
||||
api_key="sk-test",
|
||||
base_url=None,
|
||||
)
|
||||
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].get("thinking_level"), "high")
|
||||
self.assertFalse(calls[0].get("include_thoughts"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
158
tests/test_media_interaction.py
Normal file
158
tests/test_media_interaction.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
|
||||
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
|
||||
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
|
||||
setattr(sys.modules["transmission_rpc"], "File", object)
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
|
||||
from app.chain.interaction import MediaInteractionChain, media_interaction_manager
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestMediaInteraction(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
media_interaction_manager.clear()
|
||||
|
||||
@staticmethod
|
||||
def _build_meta(name: str) -> MetaBase:
|
||||
meta = MetaBase(name)
|
||||
meta.name = name
|
||||
meta.begin_season = 1
|
||||
return meta
|
||||
|
||||
def test_message_routes_text_reply_to_media_interaction_before_ai(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
self.assertIsNotNone(request)
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="1",
|
||||
)
|
||||
|
||||
handle_text.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
def test_callback_routes_to_media_interaction_chain(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
|
||||
return_value=True,
|
||||
) as handle_callback:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:media:{request.request_id}:page-next",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
handle_callback.assert_called_once()
|
||||
|
||||
def test_media_interaction_starts_search_and_posts_media_list(self):
|
||||
chain = MediaInteractionChain()
|
||||
meta = self._build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
MediaInfo(title="Interstellar", year="2014"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.chain.interaction.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
), patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_text_interaction(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="星际穿越",
|
||||
)
|
||||
|
||||
self.assertTrue(handled)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertTrue(notification.buttons)
|
||||
self.assertTrue(
|
||||
notification.buttons[0][0]["callback_data"].startswith("media:")
|
||||
)
|
||||
|
||||
request = media_interaction_manager.get_by_user("10001")
|
||||
self.assertIsNotNone(request)
|
||||
self.assertEqual(request.action, "Search")
|
||||
self.assertEqual(len(request.items), 2)
|
||||
|
||||
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
|
||||
chain = MediaInteractionChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[
|
||||
MediaInfo(title=f"资源 {index}", year="2024")
|
||||
for index in range(1, 11)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_callback_interaction(
|
||||
callback_data="page_n",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
|
||||
self.assertTrue(handled)
|
||||
self.assertEqual(request.page, 1)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertEqual(notification.original_message_id, 123)
|
||||
self.assertEqual(notification.original_chat_id, "456")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user