Compare commits

...

34 Commits

Author SHA1 Message Date
jxxghp
1ded58adbb fix: adapt audiences user data parser 2026-04-27 12:56:45 +08:00
jxxghp
019a077407 Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-27 11:35:44 +08:00
PKC278
0f190057d3 fix #5528 2026-04-27 11:35:44 +08:00
jxxghp
840c8f7298 更新 APP_VERSION 至 v2.10.6 2026-04-27 11:32:39 +08:00
jxxghp
6a6bcf59a0 增强 execute_command 工具:支持输出截断、并发限制与进程组清理,新增单元测试 2026-04-27 10:05:25 +08:00
jxxghp
323844b26d revert execute_command streaming changes
Restore the previous subprocess handling for execute_command and drop the new command streaming test so agent startup is unblocked.
2026-04-27 08:12:37 +08:00
jxxghp
140d224a9a fix agent stream blocking during command execution
Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
2026-04-27 07:57:32 +08:00
jxxghp
7bc032d17c Revert Telegram duplicate edit fix 2026-04-27 07:36:13 +08:00
jxxghp
2df476dbff Fix Telegram duplicate message edits 2026-04-27 07:17:58 +08:00
jxxghp
bae086d8b8 更新 __init__.py 2026-04-27 06:57:18 +08:00
jxxghp
221eb21694 refine internal middleware llm usage for streaming agents
Use a non-streaming model for middleware-only calls so internal outputs do not leak into user streams and model-based middleware stays consistent.
2026-04-27 06:55:41 +08:00
jxxghp
4208c79d72 refine tool提示语为更简洁风格,补充last_buffer_char属性及非VERBOSE模式流式输出换行逻辑,新增工具流式分隔符单元测试 2026-04-26 11:15:11 +08:00
jxxghp
90245a13e1 refine non-verbose prompt wording 2026-04-26 08:54:07 +08:00
jxxghp
b5979b9b09 refine agent subscription defaults and silent tool prompts 2026-04-26 08:51:56 +08:00
jxxghp
0277288a41 feat: add agent session usage status reporting
Track per-session model and token usage so users can inspect context pressure and cumulative usage with /session_status.
2026-04-26 08:19:05 +08:00
jxxghp
79bfeaf2af 移除工具调用前的流重置,保留模型思考文本可见 2026-04-25 23:12:34 +08:00
jxxghp
4fe41ba5e9 更新 base.py 2026-04-25 22:16:15 +08:00
jxxghp
14d6e2febc Refine agent prompts for concise professional replies 2026-04-25 22:04:35 +08:00
jxxghp
97c7e71207 更新 Agent Prompt.txt 2026-04-25 21:51:47 +08:00
jxxghp
8f29a218ea chore: bump version to v2.10.5 2026-04-25 12:55:33 +08:00
jxxghp
4fd5aa3eb6 fix: improve DeepSeek reasoning_content payload handling and update langchain dependencies 2026-04-25 12:46:21 +08:00
jxxghp
bfc27d151c 更新 ask_user_choice.py 2026-04-25 11:36:36 +08:00
jxxghp
f2b56b8f40 更新 ask_user_choice.py 2026-04-25 11:35:32 +08:00
jxxghp
a05ffc07d4 refactor: remove legacy LLM_DISABLE_THINKING and LLM_REASONING_EFFORT config, unify thinking_level handling
- Eliminate support for LLM_DISABLE_THINKING and LLM_REASONING_EFFORT in config, code, and tests
- Simplify LLM thinking level logic to rely solely on LLM_THINKING_LEVEL
- Refactor LLMHelper and related endpoints to remove legacy parameter handling
- Update system API and test utilities to match new configuration structure
- Minor code cleanup and formatting improvements
2026-04-25 10:42:03 +08:00
jxxghp
4a81417fb7 fix: preserve deepseek reasoning content in tool loops 2026-04-25 09:37:01 +08:00
jxxghp
c7fa3dc863 feat: unify llm thinking level controls 2026-04-24 19:50:23 +08:00
jxxghp
28f9756dd6 feat: improve skill instructions with highlighted command formatting 2026-04-22 18:12:21 +08:00
jxxghp
4bffe2cff1 chore: bump version to v2.10.4 2026-04-22 18:02:28 +08:00
jxxghp
fca478f1d8 feat: support custom skill sources in /skills 2026-04-22 18:00:57 +08:00
Sebastian
097dff13a3 feat: add ai-compatible API endpoints 2026-04-22 17:21:43 +08:00
jxxghp
460b386004 feat: add searchable skills marketplace 2026-04-22 16:49:42 +08:00
jxxghp
89bf89c02d feat: add clawhub skill registry source 2026-04-22 16:22:10 +08:00
jxxghp
cefb60ba2c refactor: unify message interactions 2026-04-22 15:18:04 +08:00
jxxghp
8c78627647 feat: add skills marketplace management 2026-04-22 14:55:00 +08:00
105 changed files with 8406 additions and 1533 deletions

View File

@@ -4,7 +4,8 @@ import re
import traceback
import uuid
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from langchain.agents import create_agent
from langchain.agents.middleware import (
@@ -24,6 +25,7 @@ from app.agent.middleware.jobs import JobsMiddleware
from app.agent.middleware.memory import MemoryMiddleware
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
from app.agent.middleware.skills import SkillsMiddleware
from app.agent.middleware.usage import UsageMiddleware
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
@@ -41,6 +43,39 @@ class AgentChain(ChainBase):
pass
@dataclass
class _SessionUsageSnapshot:
    """Mutable per-session accumulator of LLM model and token-usage statistics."""

    # Model identity and context-window size (filled lazily from the LLM profile).
    model: Optional[str] = None
    context_window_tokens: Optional[int] = None
    # Usage figures for the most recent model call.
    last_input_tokens: int = 0
    last_output_tokens: int = 0
    last_total_tokens: int = 0
    last_context_usage_ratio: Optional[float] = None
    # Cumulative usage across the whole session.
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    model_call_count: int = 0
    last_updated_at: Optional[datetime] = None

    def to_dict(self, session_id: str) -> dict[str, Any]:
        """Serialize the snapshot into a JSON-friendly dict for status reporting.

        :param session_id: identifier of the session this snapshot belongs to
        :return: flat dict; ``last_updated_at`` is rendered as
            ``YYYY-MM-DD HH:MM:SS`` or ``None`` when no call was recorded yet.
        """
        return {
            "session_id": session_id,
            "model": self.model,
            "context_window_tokens": self.context_window_tokens,
            "last_input_tokens": self.last_input_tokens,
            "last_output_tokens": self.last_output_tokens,
            "last_total_tokens": self.last_total_tokens,
            "last_context_usage_ratio": self.last_context_usage_ratio,
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_tokens,
            "model_call_count": self.model_call_count,
            "last_updated_at": self.last_updated_at.strftime("%Y-%m-%d %H:%M:%S")
            if self.last_updated_at
            else None,
        }
class _ThinkTagStripper:
"""
流式剥离 <think>...</think> 标签的辅助类。
@@ -73,7 +108,7 @@ class _ThinkTagStripper:
on_output(self.buffer[:start_idx])
emitted = True
self.in_think_tag = True
self.buffer = self.buffer[start_idx + 7:]
self.buffer = self.buffer[start_idx + 7 :]
else:
# 检查是否以 <think> 的不完整前缀结尾
partial_match = False
@@ -93,7 +128,7 @@ class _ThinkTagStripper:
end_idx = self.buffer.find("</think>")
if end_idx != -1:
self.in_think_tag = False
self.buffer = self.buffer[end_idx + 8:]
self.buffer = self.buffer[end_idx + 8 :]
else:
# 检查是否以 </think> 的不完整前缀结尾
partial_match = False
@@ -138,10 +173,92 @@ class MoviePilotAgent:
self.force_streaming = False
self.suppress_user_reply = False
self._streamed_output = ""
self._session_usage = _SessionUsageSnapshot()
# 流式token管理
self.stream_handler = StreamingHandler()
@staticmethod
def _coerce_int(value: Any) -> Optional[int]:
    """Best-effort int conversion; returns None for None or unparsable values."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None
@classmethod
def _get_model_name(cls, llm: Any) -> Optional[str]:
    """Read the model identifier from whichever attribute the LLM client exposes.

    Different client classes name this attribute differently, so several
    spellings are tried in order.
    """
    return (
        getattr(llm, "model", None)
        or getattr(llm, "model_name", None)
        or getattr(llm, "model_id", None)
    )
@classmethod
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
    """Extract the model's max input-token window from its profile, if known.

    ``profile`` may be a dict or an arbitrary object; both spellings of the
    limit (``max_input_tokens`` / ``input_token_limit``) are tried.
    Returns None when no profile or no limit is available.
    """
    profile = getattr(llm, "profile", None)
    if not profile:
        return None
    if isinstance(profile, dict):
        return cls._coerce_int(
            profile.get("max_input_tokens") or profile.get("input_token_limit")
        )
    return cls._coerce_int(
        getattr(profile, "max_input_tokens", None)
        or getattr(profile, "input_token_limit", None)
    )
def _sync_model_profile(self, llm: Any) -> None:
    """Seed the session usage snapshot with the LLM's name and context window.

    Existing snapshot values are only overwritten when the LLM actually
    exposes the corresponding information.
    """
    model_name = self._get_model_name(llm)
    context_window_tokens = self._get_context_window_tokens(llm)
    if model_name:
        self._session_usage.model = model_name
    if context_window_tokens:
        self._session_usage.context_window_tokens = context_window_tokens
def _record_usage(self, usage: dict[str, Any]) -> None:
    """Fold one model-call usage report (from UsageMiddleware) into the snapshot.

    The call counter and timestamp are always bumped; token counters are
    only updated when the report actually carries usage data (``has_usage``).

    :param usage: usage dict produced by ``UsageMiddleware.awrap_model_call``
    """
    if not usage:
        return
    model_name = usage.get("model")
    context_window_tokens = self._coerce_int(usage.get("context_window_tokens"))
    if model_name:
        self._session_usage.model = model_name
    if context_window_tokens:
        self._session_usage.context_window_tokens = context_window_tokens
    self._session_usage.model_call_count += 1
    self._session_usage.last_updated_at = datetime.now()
    if not usage.get("has_usage"):
        return
    input_tokens = self._coerce_int(usage.get("input_tokens")) or 0
    output_tokens = self._coerce_int(usage.get("output_tokens")) or 0
    total_tokens = self._coerce_int(usage.get("total_tokens"))
    if total_tokens is None:
        # Fall back to the sum when the provider omits an explicit total.
        total_tokens = input_tokens + output_tokens
    self._session_usage.last_input_tokens = input_tokens
    self._session_usage.last_output_tokens = output_tokens
    self._session_usage.last_total_tokens = total_tokens
    self._session_usage.last_context_usage_ratio = usage.get("context_usage_ratio")
    self._session_usage.total_input_tokens += input_tokens
    self._session_usage.total_output_tokens += output_tokens
    self._session_usage.total_tokens += total_tokens
def get_session_status(self) -> dict[str, Any]:
    """Return the current session's model/usage status as a dict.

    Falls back to configured defaults when no model call has populated the
    snapshot yet. NOTE(review): assumes ``LLM_MAX_CONTEXT_TOKENS`` is
    expressed in thousands of tokens (hence the ``* 1000``) — confirm
    against the settings definition.
    """
    if not self._session_usage.model:
        self._session_usage.model = settings.LLM_MODEL
    if not self._session_usage.context_window_tokens:
        self._session_usage.context_window_tokens = (
            settings.LLM_MAX_CONTEXT_TOKENS * 1000
            if settings.LLM_MAX_CONTEXT_TOKENS
            else None
        )
    return self._session_usage.to_dict(self.session_id)
@property
def is_background(self) -> bool:
"""
@@ -258,6 +375,12 @@ class MoviePilotAgent:
# LLM 模型(用于 agent 执行)
llm = self._initialize_llm(streaming=streaming)
self._sync_model_profile(llm)
# 为中间件内部模型调用准备非流式 LLM避免与用户流式回复复用同一实例。
non_streaming_llm = (
llm if not streaming else self._initialize_llm(streaming=False)
)
# 工具列表
tools = self._initialize_tools()
@@ -279,8 +402,12 @@ class MoviePilotAgent:
ActivityLogMiddleware(
activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"),
),
# 用量统计
UsageMiddleware(on_usage=self._record_usage),
# 上下文压缩
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
SummarizationMiddleware(
model=non_streaming_llm, trigger=("fraction", 0.85)
),
# 错误工具调用修复
PatchToolCallsMiddleware(),
]
@@ -289,7 +416,8 @@ class MoviePilotAgent:
if settings.LLM_MAX_TOOLS > 0:
middlewares.append(
LLMToolSelectorMiddleware(
model=llm, max_tools=settings.LLM_MAX_TOOLS
model=non_streaming_llm,
max_tools=settings.LLM_MAX_TOOLS,
)
)
@@ -371,10 +499,6 @@ class MoviePilotAgent:
:param on_token: 收到有效 token 时的回调
"""
stripper = _ThinkTagStripper()
# 非VERBOSE模式下跟踪当前langgraph_step以检测中间步骤的模型输出
# 当模型在工具调用之前输出的"计划/思考"文本会在检测到tool_call时被清除
current_model_step = -1
has_emitted_in_step = False
async for chunk in agent.astream(
messages,
@@ -388,25 +512,13 @@ class MoviePilotAgent:
if not token or not hasattr(token, "tool_call_chunks"):
continue
# 获取当前步骤信息
step = metadata.get("langgraph_step", -1) if metadata else -1
if token.tool_call_chunks:
# 检测到工具调用token说明当前步骤是中间步骤
# 非VERBOSE模式下清除该步骤之前输出的"计划/思考"文本
if not settings.AI_AGENT_VERBOSE and has_emitted_in_step:
self.stream_handler.reset()
stripper.reset()
has_emitted_in_step = False
# 清除 stripper 内部缓冲中可能残留的 <think> 标签中间状态
stripper.reset()
continue
# 以下处理纯文本tokentool_call_chunks为空
# 检测步骤变化重置步骤内emit跟踪
if step != current_model_step:
current_model_step = step
has_emitted_in_step = False
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content
additional = getattr(token, "additional_kwargs", None)
if additional and additional.get("reasoning_content"):
@@ -416,8 +528,7 @@ class MoviePilotAgent:
# content 可能是字符串或内容块列表,过滤掉思考类型的块
content = self._extract_text_content(token.content)
if content:
if stripper.process(content, on_token):
has_emitted_in_step = True
stripper.process(content, on_token)
stripper.flush(on_token)
@@ -457,7 +568,10 @@ class MoviePilotAgent:
agent=agent,
messages={"messages": messages},
config=agent_config,
on_token=lambda token: (self.stream_handler.emit(token), self._emit_output(token)),
on_token=lambda token: (
self.stream_handler.emit(token),
self._emit_output(token),
),
)
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
@@ -622,6 +736,37 @@ class AgentManager:
# 重试整理缓冲区锁
self._retry_transfer_lock = asyncio.Lock()
def get_session_status(self, session_id: str) -> dict[str, Any]:
    """Return the current model and token-usage status for a session.

    :param session_id: session to inspect
    :return: usage dict — the live agent's snapshot when one is active,
        otherwise configured defaults with zeroed counters — plus queue and
        worker state fields.
    """
    agent = self.active_agents.get(session_id)
    if agent:
        status = agent.get_session_status()
    else:
        # No live agent for this session: report configured defaults.
        # NOTE(review): assumes LLM_MAX_CONTEXT_TOKENS is in thousands of
        # tokens (hence * 1000) — confirm against settings.
        status = {
            "session_id": session_id,
            "model": settings.LLM_MODEL,
            "context_window_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000
            if settings.LLM_MAX_CONTEXT_TOKENS
            else None,
            "last_input_tokens": 0,
            "last_output_tokens": 0,
            "last_total_tokens": 0,
            "last_context_usage_ratio": None,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "total_tokens": 0,
            "model_call_count": 0,
            "last_updated_at": None,
        }
    queue = self._session_queues.get(session_id)
    status["pending_messages"] = queue.qsize() if queue else 0
    status["is_processing"] = (
        session_id in self._session_workers
        and not self._session_workers[session_id].done()
    )
    return status
@staticmethod
async def initialize():
"""
@@ -1004,7 +1149,6 @@ class AgentManager:
)
try:
await self.process_message(
session_id=session_id,
user_id=user_id,

View File

@@ -2,6 +2,8 @@ import asyncio
import threading
from typing import Optional, Tuple
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -180,7 +182,7 @@ class StreamingHandler:
# 检查是否所有缓冲内容都已发送
with self._lock:
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
current_msg_text = self._buffer[self._msg_start_offset :]
current_msg_text = self._buffer[self._msg_start_offset:]
all_sent = (
self._message_response is not None
and self._sent_text
@@ -246,7 +248,7 @@ class StreamingHandler:
"""
with self._lock:
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
current_text = self._buffer[self._msg_start_offset :]
current_text = self._buffer[self._msg_start_offset:]
if not current_text or current_text == self._sent_text:
# 没有新内容需要刷新
return
@@ -256,7 +258,8 @@ class StreamingHandler:
try:
if self._message_response is None:
# 第一次发送:发送新消息并获取 message_id
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -264,7 +267,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -291,13 +294,14 @@ class StreamingHandler:
)
with self._lock:
self._msg_start_offset += len(self._sent_text)
current_text = self._buffer[self._msg_start_offset :]
current_text = self._buffer[self._msg_start_offset:]
self._message_response = None
self._sent_text = ""
# 如果偏移后还有新内容,立即发送为新消息
if current_text:
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -305,7 +309,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -324,7 +328,8 @@ class StreamingHandler:
except (ValueError, KeyError):
return
success = chain.edit_message(
success = await run_in_threadpool(
chain.edit_message,
channel=channel_enum,
source=self._message_response.source,
message_id=self._message_response.message_id,
@@ -360,3 +365,11 @@ class StreamingHandler:
是否已经通过流式输出发送过消息(当前轮次)
"""
return self._message_response is not None
@property
def last_buffer_char(self) -> str:
    """
    Last character currently in the stream buffer; empty string when the
    buffer is empty. Takes the lock so the read is consistent with writers.
    """
    with self._lock:
        return self._buffer[-1:] if self._buffer else ""

View File

@@ -1,107 +0,0 @@
"""Agent 客户端交互请求管理。"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List, Optional
import uuid
@dataclass(frozen=True)
class AgentInteractionOption:
    """A single selectable option offered to the user (immutable, hashable)."""

    # Human-readable text shown for the choice.
    label: str
    # Machine value returned when the option is selected.
    value: str
@dataclass
class PendingAgentInteraction:
    """A pending client-interaction request initiated by the agent."""

    request_id: str
    session_id: str
    user_id: str
    channel: Optional[str]
    source: Optional[str]
    username: Optional[str]
    title: Optional[str]
    # Question text shown to the user.
    prompt: str
    options: List[AgentInteractionOption]
    # Creation time; used for TTL-based expiry.
    created_at: datetime = field(default_factory=datetime.now)
class AgentInteractionManager:
    """Registry of client-interaction requests initiated by the agent.

    Thread-safe; entries expire automatically after ``_ttl``.
    """

    _ttl = timedelta(hours=24)

    def __init__(self):
        self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
        self._lock = Lock()

    def _cleanup_locked(self):
        """Drop expired entries. Caller must already hold ``self._lock``."""
        cutoff = datetime.now() - self._ttl
        stale_ids = []
        for rid, pending in self._pending_interactions.items():
            if pending.created_at < cutoff:
                stale_ids.append(rid)
        for rid in stale_ids:
            self._pending_interactions.pop(rid, None)

    def create_request(
        self,
        session_id: str,
        user_id: str,
        channel: Optional[str],
        source: Optional[str],
        username: Optional[str],
        title: Optional[str],
        prompt: str,
        options: List[AgentInteractionOption],
    ) -> PendingAgentInteraction:
        """Register a new interaction request and return it.

        A short unique request id (12 hex chars) is generated; collisions
        are retried until an unused id is found.
        """
        with self._lock:
            self._cleanup_locked()
            new_id = uuid.uuid4().hex[:12]
            while new_id in self._pending_interactions:
                new_id = uuid.uuid4().hex[:12]
            pending = PendingAgentInteraction(
                request_id=new_id,
                session_id=session_id,
                user_id=str(user_id),
                channel=channel,
                source=source,
                username=username,
                title=title,
                prompt=prompt,
                options=options,
            )
            self._pending_interactions[new_id] = pending
            return pending

    def resolve(
        self,
        request_id: str,
        option_index: int,
        user_id: Optional[str] = None,
    ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
        """Consume a pending request by selecting one of its options.

        :param request_id: id of the pending request
        :param option_index: 1-based index of the chosen option
        :param user_id: when given, must match the request's owner
        :return: ``(request, option)`` on success; ``None`` for an unknown or
            expired id, an owner mismatch, or an out-of-range index.
        """
        with self._lock:
            self._cleanup_locked()
            pending = self._pending_interactions.get(request_id)
            if pending is None:
                return None
            if user_id is not None and str(pending.user_id) != str(user_id):
                return None
            if not 1 <= option_index <= len(pending.options):
                return None
            chosen = pending.options[option_index - 1]
            self._pending_interactions.pop(request_id, None)
            return pending, chosen

    def clear(self):
        """Drop every pending request."""
        with self._lock:
            self._pending_interactions.clear()


agent_interaction_manager = AgentInteractionManager()

View File

@@ -124,34 +124,29 @@ Default memory file: {memory_file}
</agent_memory>
<memory_onboarding>
**IMPORTANT — First-time user detected!**
First-time user detected.
The memory directory is currently empty. This means this is likely the user's first interaction, or their preferences have been reset.
The memory directory is currently empty. This likely means the user has no saved long-term preferences yet.
**Your MANDATORY first action in this conversation:**
Before doing ANYTHING else (before answering questions, before calling tools, before performing any task), you MUST proactively greet the user warmly and ask them about their preferences so you can provide personalized service going forward. Specifically, ask about:
**Behavior requirements:**
- Do NOT interrupt the current task just to collect preferences.
- Do NOT proactively greet warmly, build rapport, or ask a long onboarding questionnaire.
- Default to a concise, professional style until the user states a preference.
- Only ask for preferences when they are directly useful for the current task, or when a short follow-up question at the end would clearly help future interactions.
1. **How to address the user** — Ask what name or nickname they'd like you to call them (e.g., a real name, a nickname, or a fun title). This is the top priority for building a personal connection.
2. **Communication style preference** — Do they prefer a cute/playful tone (with emojis), a formal/professional tone, a concise/minimalist style, or something else?
3. **Media preferences** — What types of media do they primarily care about? (e.g., movies, TV shows, anime, documentaries, etc.)
4. **Quality preferences** — Do they have preferred video quality (4K, 1080p), codecs (H.265, H.264), or subtitle language preferences?
5. **Any other special requests** — Anything else they'd like you to always keep in mind?
**What to collect when useful:**
- Preferred communication style
- Media interests
- Quality / codec / subtitle preferences
- Any standing rules the user wants you to follow
**After the user replies**, you MUST immediately:
1. Use the `write_file` tool to save ALL their preferences to the memory file at: `{memory_file}`
2. Format the memory file in clean Markdown with clear sections (e.g., `## User Profile`, `## Communication Style`, `## Media Preferences`, etc.)
3. The `## User Profile` section MUST include the user's preferred name/nickname at the top
4. Only AFTER saving the preferences, proceed to help with whatever the user originally asked about (if anything)
5. From this point on, always address the user by their preferred name/nickname in conversations
6. You may also create additional `.md` files in the memory directory (`{memory_dir}`) for different topics as needed.
**When the user provides lasting preferences**, you MUST promptly save them to `{memory_file}` using `write_file` or `edit_file`.
**If the user skips the preference questions** and directly asks you to do something:
- Go ahead and help them with their request first
- But still ask about their preferences naturally at the end of the interaction
- Save whatever you learn about them (implicit or explicit) to the memory file
**Example onboarding flow:**
The greeting should introduce yourself, explain this is the first meeting, and ask the above questions in a numbered list. Adapt the tone to your persona defined in the base system prompt.
**Memory format requirements:**
- Use clean Markdown with short sections.
- Record only durable preferences and working rules.
- Do NOT invent personal details or preferred names.
- Do NOT force use of a nickname or personalized greeting.
</memory_onboarding>
<memory_guidelines>

View File

@@ -0,0 +1,184 @@
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import AIMessage
from app.log import logger
class UsageMiddleware(AgentMiddleware):
    """Record per-model-call usage metadata and report it to the owning session."""

    def __init__(
        self,
        *,
        on_usage: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        # Callback invoked after each model call with a usage summary dict.
        self.on_usage = on_usage

    @staticmethod
    def _coerce_int(value: Any) -> int | None:
        """Best-effort int conversion; None for missing/unparsable values."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @classmethod
    def _lookup_int(cls, container: Any, *keys: str) -> int | None:
        """Return the first non-None value under *keys*, coerced to int.

        Works for both dict-likes (via a callable ``get``) and plain objects
        (via attribute access); returns None when nothing is found.
        """
        if not container:
            return None
        getter = getattr(container, "get", None)
        if callable(getter):
            for key in keys:
                value = getter(key)
                if value is not None:
                    return cls._coerce_int(value)
        for key in keys:
            value = getattr(container, key, None)
            if value is not None:
                return cls._coerce_int(value)
        return None

    @classmethod
    def _extract_model_name(cls, model: Any) -> str | None:
        """Read the model identifier from whichever attribute the client exposes."""
        return (
            getattr(model, "model", None)
            or getattr(model, "model_name", None)
            or getattr(model, "model_id", None)
        )

    @classmethod
    def _extract_context_window_tokens(cls, model: Any) -> int | None:
        """Max input-token window from the model's profile, when available."""
        profile = getattr(model, "profile", None)
        if not profile:
            return None
        return cls._lookup_int(profile, "max_input_tokens", "input_token_limit")

    @classmethod
    def _extract_usage(cls, ai_message: AIMessage) -> dict[str, Any]:
        """Normalize token usage from an AIMessage into a flat dict.

        Sources are tried in precedence order: the standard ``usage_metadata``
        first, then provider-specific keys inside ``response_metadata``
        (OpenAI-style ``prompt_tokens``/``completion_tokens``, and
        ``*_token_count`` variants). ``has_usage`` flags whether any counter
        was actually reported; missing totals fall back to input + output.
        """
        usage_metadata = getattr(ai_message, "usage_metadata", None)
        input_tokens = cls._lookup_int(usage_metadata, "input_tokens")
        output_tokens = cls._lookup_int(usage_metadata, "output_tokens")
        total_tokens = cls._lookup_int(usage_metadata, "total_tokens")
        response_metadata = getattr(ai_message, "response_metadata", None) or {}
        token_usage = (
            response_metadata.get("token_usage")
            or response_metadata.get("usage")
            or response_metadata.get("usage_metadata")
            or {}
        )
        if input_tokens is None:
            input_tokens = cls._lookup_int(
                token_usage,
                "prompt_tokens",
                "input_tokens",
            )
        if input_tokens is None:
            input_tokens = cls._lookup_int(
                response_metadata,
                "prompt_token_count",
                "input_tokens",
            )
        if output_tokens is None:
            output_tokens = cls._lookup_int(
                token_usage,
                "completion_tokens",
                "output_tokens",
            )
        if output_tokens is None:
            output_tokens = cls._lookup_int(
                response_metadata,
                "candidates_token_count",
                "output_tokens",
            )
        if total_tokens is None:
            total_tokens = cls._lookup_int(token_usage, "total_tokens")
        if total_tokens is None:
            total_tokens = cls._lookup_int(response_metadata, "total_token_count")
        has_usage = any(
            value is not None for value in (input_tokens, output_tokens, total_tokens)
        )
        resolved_input = input_tokens or 0
        resolved_output = output_tokens or 0
        resolved_total = (
            total_tokens
            if total_tokens is not None
            else resolved_input + resolved_output
        )
        return {
            "has_usage": has_usage,
            "input_tokens": resolved_input,
            "output_tokens": resolved_output,
            "total_tokens": resolved_total,
        }

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> ModelResponse[ResponseT]:
        """Invoke the model, then report usage of the latest AIMessage via ``on_usage``.

        Reporting is best-effort: any failure is logged at debug level and
        never affects the model response returned to the agent.
        """
        response = await handler(request)
        if not callable(self.on_usage):
            return response
        try:
            # Usage data lives on the most recent AIMessage in the result.
            ai_message = next(
                (
                    message
                    for message in reversed(response.result)
                    if isinstance(message, AIMessage)
                ),
                None,
            )
            usage = (
                self._extract_usage(ai_message)
                if ai_message
                else {
                    "has_usage": False,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                }
            )
            context_window_tokens = self._extract_context_window_tokens(request.model)
            context_usage_ratio = None
            if context_window_tokens and usage["has_usage"]:
                # Fraction of the model's context window consumed by the prompt.
                context_usage_ratio = usage["input_tokens"] / context_window_tokens
            self.on_usage(
                {
                    "model": self._extract_model_name(request.model),
                    "context_window_tokens": context_window_tokens,
                    "context_usage_ratio": context_usage_ratio,
                    **usage,
                }
            )
        except Exception as e:
            logger.debug("记录模型 usage 失败: %s", e)
        return response


__all__ = ["UsageMiddleware"]

View File

@@ -15,9 +15,12 @@ Core Capabilities:
<communication>
{verbose_spec}
- Tone: friendly, concise. Like a knowledgeable friend, not a corporate bot.
- Use emojis sparingly (1-3 per response): greetings, completions, errors.
- Tone: professional, concise, restrained.
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
- Prioritize task progress over conversation. Answer only what is necessary to move the task forward.
- Do NOT flatter the user, praise the question, or use overly eager/service-oriented phrases.
- Do NOT use emojis, exclamation marks, cute language, or excessive apology.
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
- Use Markdown for structured data. Use `inline code` for media titles/paths.
- Include key details (year, rating, resolution) but do NOT over-explain.
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
@@ -34,6 +37,7 @@ Core Capabilities:
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
- NO repeating what user said.
- NO narrating your internal reasoning.
- NO praise, emotional cushioning, or unnecessary politeness padding.
- After task completion: one line summary only.
- When error occurs: brief acknowledgment + suggestion, then move on.
</response_format>
@@ -56,6 +60,7 @@ Core Capabilities:
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
</media_management_rules>
<markdown_spec>

View File

@@ -82,11 +82,13 @@ class PromptManager:
verbose_spec = ""
if not settings.AI_AGENT_VERBOSE:
verbose_spec = (
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
"text, thinking processes, or explanations before or during tool calls. Call tools "
"directly without any transitional phrases. "
"You MUST remain completely silent until the task is completely finished. "
"DO NOT output any content whatsoever until your final summary reply."
"\n\n[Important Instruction] STRICTLY ENFORCED: "
"If tools are needed, DO NOT output any conversational text, explanations, progress updates, "
"or acknowledgements before the first tool call or between tool calls. "
"Call tools directly without any transitional phrases. "
"You MUST remain completely silent until all required tools have finished and you have the final result. "
"Only then may you send one final user-facing reply. "
"DO NOT output any intermediate content whatsoever."
)
# MoviePilot系统信息
@@ -193,18 +195,18 @@ class PromptManager:
def _generate_button_choice_instructions(
channel: MessageChannel = None,
) -> str:
if channel and ChannelCapabilityManager.supports_buttons(
if (
channel
) and ChannelCapabilityManager.supports_callbacks(channel):
and ChannelCapabilityManager.supports_buttons(channel)
and ChannelCapabilityManager.supports_callbacks(channel)
):
return (
"- User questions: If you need the user to choose from a few clear options, "
"call `ask_user_choice` to send button options. After the user clicks a button, "
"the selected value will come back as the user's next message. After calling this tool, "
"wait for the user's selection instead of repeating the question in plain text."
)
return (
"- User questions: When you truly need user input, ask briefly in plain text."
)
return "- User questions: When you truly need user input, ask briefly in plain text."
def clear_cache(self):
"""

View File

@@ -82,8 +82,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
else:
# 非VERBOSE,重置缓冲区从头更新,保持消息编辑能力
self._stream_handler.reset()
# 非VERBOSE:工具边界至少补一个换行,避免工具前后的文本直接连在一起
if self._stream_handler.last_buffer_char not in ("", "\n"):
self._stream_handler.emit("\n")
else:
# 未启用流式传输,不发送任何工具消息内容
pass

View File

@@ -47,13 +47,13 @@ class AddDownloadTool(MoviePilotTool):
if torrent_urls:
if len(torrent_urls) == 1:
if self._is_torrent_ref(torrent_urls[0]):
message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
message = f"添加下载任务: 资源 {torrent_urls[0]}"
else:
message = "正在添加下载任务: 磁力链接"
message = "添加下载任务: 磁力链接"
else:
message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
message = f"批量添加下载任务: 共 {len(torrent_urls)} 个资源"
else:
message = "正在添加下载任务"
message = "添加下载任务"
if downloader:
message += f" [下载器: {downloader}]"

View File

@@ -12,36 +12,74 @@ from app.schemas.types import MediaType
class AddSubscribeInput(BaseModel):
"""添加订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
year: str = Field(..., description="Release year of the media (required for accurate identification)")
media_type: str = Field(...,
description="Allowed values: movie, tv")
season: Optional[int] = Field(None,
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
tmdb_id: Optional[int] = Field(None,
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)")
douban_id: Optional[str] = Field(None,
description="Douban ID for precise media identification (optional, alternative to tmdb_id)")
start_episode: Optional[int] = Field(None,
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
total_episode: Optional[int] = Field(None,
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)")
quality: Optional[str] = Field(None,
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
resolution: Optional[str] = Field(None,
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
effect: Optional[str] = Field(None,
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
filter_groups: Optional[List[str]] = Field(None,
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)")
sites: Optional[List[int]] = Field(None,
description="List of site IDs to search from (optional, can be obtained from query_sites tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
title: str = Field(
...,
description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')",
)
year: str = Field(
...,
description="Release year of the media (required for accurate identification)",
)
media_type: str = Field(..., description="Allowed values: movie, tv")
season: Optional[int] = Field(
None,
description=(
"Season number for TV shows (optional). If omitted, the subscription defaults to season 1 only. "
"To subscribe multiple seasons or the full series, call this tool separately for each season."
),
)
tmdb_id: Optional[int] = Field(
None,
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)",
)
douban_id: Optional[str] = Field(
None,
description="Douban ID for precise media identification (optional, alternative to tmdb_id)",
)
start_episode: Optional[int] = Field(
None,
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)",
)
total_episode: Optional[int] = Field(
None,
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)",
)
quality: Optional[str] = Field(
None,
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
)
resolution: Optional[str] = Field(
None,
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
)
effect: Optional[str] = Field(
None,
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
)
filter_groups: Optional[List[str]] = Field(
None,
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)",
)
sites: Optional[List[int]] = Field(
None,
description="List of site IDs to search from (optional, can be obtained from query_sites tool)",
)
class AddSubscribeTool(MoviePilotTool):
name: str = "add_subscribe"
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria. Supports advanced filtering options like quality, resolution, and effect filters using regular expressions."
description: str = (
"Add media subscription to create automated download rules for movies and TV shows. "
"The system will automatically search and download new episodes or releases based on the subscription criteria. "
"For TV shows, omitting `season` subscribes season 1 only by default; to subscribe multiple seasons or "
"the full series, call this tool once per season. Supports advanced filtering options like quality, "
"resolution, and effect filters using regular expressions."
)
args_schema: Type[BaseModel] = AddSubscribeInput
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -50,52 +88,72 @@ class AddSubscribeTool(MoviePilotTool):
year = kwargs.get("year", "")
media_type = kwargs.get("media_type", "")
season = kwargs.get("season")
message = f"正在添加订阅: {title}"
message = f"添加订阅: {title}"
if year:
message += f" ({year})"
if media_type:
message += f" [{media_type}]"
if season:
message += f"{season}"
elif media_type == "tv":
message += " 第1季(默认)"
return message
async def run(self, title: str, year: str, media_type: str,
season: Optional[int] = None, tmdb_id: Optional[int] = None,
douban_id: Optional[str] = None,
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
quality: Optional[str] = None, resolution: Optional[str] = None,
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
sites: Optional[List[int]] = None, **kwargs) -> str:
async def run(
self,
title: str,
year: str,
media_type: str,
season: Optional[int] = None,
tmdb_id: Optional[int] = None,
douban_id: Optional[str] = None,
start_episode: Optional[int] = None,
total_episode: Optional[int] = None,
quality: Optional[str] = None,
resolution: Optional[str] = None,
effect: Optional[str] = None,
filter_groups: Optional[List[str]] = None,
sites: Optional[List[int]] = None,
**kwargs,
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, "
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
f"effect={effect}, filter_groups={filter_groups}, sites={sites}"
)
try:
subscribe_chain = SubscribeChain()
media_type_enum = MediaType.from_agent(media_type)
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
effective_season = (
season
if season is not None
else 1
if media_type_enum == MediaType.TV
else None
)
# 构建额外的订阅参数
subscribe_kwargs = {}
if start_episode is not None:
subscribe_kwargs['start_episode'] = start_episode
subscribe_kwargs["start_episode"] = start_episode
if total_episode is not None:
subscribe_kwargs['total_episode'] = total_episode
subscribe_kwargs["total_episode"] = total_episode
if quality:
subscribe_kwargs['quality'] = quality
subscribe_kwargs["quality"] = quality
if resolution:
subscribe_kwargs['resolution'] = resolution
subscribe_kwargs["resolution"] = resolution
if effect:
subscribe_kwargs['effect'] = effect
subscribe_kwargs["effect"] = effect
if filter_groups:
subscribe_kwargs['filter_groups'] = filter_groups
subscribe_kwargs["filter_groups"] = filter_groups
if sites:
subscribe_kwargs['sites'] = sites
subscribe_kwargs["sites"] = sites
sid, message = await subscribe_chain.async_add(
mtype=media_type_enum,
@@ -105,13 +163,21 @@ class AddSubscribeTool(MoviePilotTool):
doubanid=douban_id,
season=season,
username=self._user_id,
**subscribe_kwargs
**subscribe_kwargs,
)
if sid:
if message and "已存在" in message:
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
result_msg = f"订阅已存在:{title} ({year})"
if effective_season is not None:
result_msg += f"{effective_season}"
result_msg += "。如需修改参数请先删除旧订阅。"
return result_msg
result_msg = f"成功添加订阅:{title} ({year})"
if effective_season is not None:
result_msg += f"{effective_season}"
if season is None:
result_msg += "(未指定季号,默认按第一季订阅)"
if subscribe_kwargs:
params = []
if start_episode is not None:

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.agent.interaction import (
from app.chain.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
@@ -75,7 +75,7 @@ class AskUserChoiceTool(MoviePilotTool):
message = kwargs.get("message", "") or ""
if len(message) > 40:
message = message[:40] + "..."
return f"正在发送按钮选择: {message}"
return f"发送按钮选择: {message}"
@staticmethod
def _truncate_button_text(text: str, max_length: int) -> str:
@@ -106,7 +106,7 @@ class AskUserChoiceTool(MoviePilotTool):
):
return f"当前渠道 {channel.value} 不支持按钮选择"
max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
max_per_row = 1
max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
max_options = max_per_row * max_rows

View File

@@ -108,16 +108,16 @@ class BrowseWebpageTool(MoviePilotTool):
url = kwargs.get("url", "")
selector = kwargs.get("selector", "")
action_messages = {
"goto": f"正在打开网页: {url}",
"get_content": "正在获取页面内容",
"screenshot": "正在截取页面截图",
"click": f"正在点击元素: {selector}",
"fill": f"正在填写表单: {selector}",
"select": f"正在选择选项: {selector}",
"evaluate": "正在执行 JavaScript",
"wait": f"正在等待元素: {selector}",
"goto": f"打开网页: {url}",
"get_content": "获取页面内容",
"screenshot": "截取页面截图",
"click": f"点击元素: {selector}",
"fill": f"填写表单: {selector}",
"select": f"选择选项: {selector}",
"evaluate": "执行 JavaScript",
"wait": f"等待元素: {selector}",
}
return action_messages.get(action, f"正在执行浏览器操作: {action}")
return action_messages.get(action, f"执行浏览器操作: {action}")
async def run(
self,

View File

@@ -41,7 +41,7 @@ class DeleteDownloadTool(MoviePilotTool):
downloader = kwargs.get("downloader")
delete_files = kwargs.get("delete_files", False)
message = f"正在删除下载任务: {hash_value}"
message = f"删除下载任务: {hash_value}"
if downloader:
message += f" [下载器: {downloader}]"
if delete_files:

View File

@@ -30,7 +30,7 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
history_id = kwargs.get("history_id")
return f"正在删除下载历史记录 ID: {history_id}"
return f"删除下载历史记录 ID: {history_id}"
async def run(self, history_id: int, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")

View File

@@ -34,7 +34,7 @@ class DeleteSubscribeTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据删除参数生成友好的提示消息"""
subscribe_id = kwargs.get("subscribe_id")
return f"正在删除订阅 (ID: {subscribe_id})"
return f"删除订阅 (ID: {subscribe_id})"
async def run(self, subscribe_id: int, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")

View File

@@ -30,7 +30,7 @@ class DeleteTransferHistoryTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据参数生成友好的提示消息"""
history_id = kwargs.get("history_id")
return f"正在删除整理历史记录: ID={history_id}"
return f"删除整理历史记录: ID={history_id}"
async def run(self, history_id: int, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")

View File

@@ -28,7 +28,7 @@ class EditFileTool(MoviePilotTool):
"""根据参数生成友好的提示消息"""
file_path = kwargs.get("file_path", "")
file_name = Path(file_path).name if file_path else "未知文件"
return f"正在编辑文件: {file_name}"
return f"编辑文件: {file_name}"
async def run(self, file_path: str, old_text: str, new_text: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")

View File

@@ -1,6 +1,10 @@
"""执行Shell命令工具"""
import asyncio
import os
import signal
import subprocess
from dataclasses import dataclass, field
from typing import Optional, Type
from pydantic import BaseModel, Field
@@ -9,6 +13,54 @@ from app.agent.tools.base import MoviePilotTool
from app.log import logger
DEFAULT_TIMEOUT_SECONDS = 60
MAX_TIMEOUT_SECONDS = 300
MAX_OUTPUT_CHARS = 6000
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
@dataclass
class _CommandOutput:
"""保存受限命令输出,避免大输出一次性进入内存。"""
limit: int
stdout_chunks: list[str] = field(default_factory=list)
stderr_chunks: list[str] = field(default_factory=list)
captured_chars: int = 0
truncated: bool = False
def append(self, stream_name: str, text: str) -> None:
if not text:
return
remaining = self.limit - self.captured_chars
if remaining <= 0:
self.truncated = True
return
captured = text[:remaining]
if stream_name == "stdout":
self.stdout_chunks.append(captured)
else:
self.stderr_chunks.append(captured)
self.captured_chars += len(captured)
if len(text) > remaining:
self.truncated = True
@property
def stdout(self) -> str:
return "".join(self.stdout_chunks).strip()
@property
def stderr(self) -> str:
return "".join(self.stderr_chunks).strip()
class ExecuteCommandInput(BaseModel):
"""执行Shell命令工具的输入参数模型"""
@@ -23,14 +75,160 @@ class ExecuteCommandInput(BaseModel):
class ExecuteCommandTool(MoviePilotTool):
name: str = "execute_command"
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
description: str = (
"Safely execute shell commands on the server. Useful for system "
"maintenance, checking status, or running custom scripts. Includes "
"timeout, concurrency, and hard output limits."
)
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"正在执行系统命令: {command}"
return f"执行系统命令: {command}"
@staticmethod
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout。"""
try:
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
except (TypeError, ValueError):
normalized = DEFAULT_TIMEOUT_SECONDS
if normalized <= 0:
return DEFAULT_TIMEOUT_SECONDS, "timeout 参数无效,已使用默认 60 秒"
if normalized > MAX_TIMEOUT_SECONDS:
return (
MAX_TIMEOUT_SECONDS,
f"timeout 参数超过上限,已从 {normalized} 秒限制为 {MAX_TIMEOUT_SECONDS}",
)
return normalized, None
@staticmethod
def _subprocess_kwargs() -> dict:
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
kwargs = {
"stdin": subprocess.DEVNULL,
"stdout": asyncio.subprocess.PIPE,
"stderr": asyncio.subprocess.PIPE,
}
if os.name == "posix":
kwargs["start_new_session"] = True
elif os.name == "nt":
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
return kwargs
@staticmethod
async def _read_stream(
stream: asyncio.StreamReader,
stream_name: str,
output: _CommandOutput,
limit_reached: asyncio.Event,
) -> None:
"""按块读取输出,达到上限后通知主流程终止命令。"""
while True:
chunk = await stream.read(READ_CHUNK_SIZE)
if not chunk:
break
if output.truncated:
limit_reached.set()
continue
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
if output.truncated:
limit_reached.set()
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
continue
@staticmethod
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
"""向进程组发送终止信号;不支持进程组的平台回退为单进程终止。"""
try:
if os.name == "posix":
os.killpg(process.pid, sig)
elif sig == getattr(signal, "SIGKILL", None):
process.kill()
else:
process.terminate()
except ProcessLookupError:
pass
@classmethod
async def _cleanup_process(
cls,
process: asyncio.subprocess.Process,
wait_task: asyncio.Task,
) -> None:
"""先温和终止,失败后强杀,避免超时 shell 遗留子进程。"""
if wait_task.done():
return
cls._terminate_process(process, signal.SIGTERM)
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
)
return
except asyncio.TimeoutError:
pass
kill_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
cls._terminate_process(process, kill_signal)
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
)
except asyncio.TimeoutError:
logger.warning("命令进程强制清理超时: pid=%s", process.pid)
@staticmethod
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
if not reader_tasks:
return
done, pending = await asyncio.wait(reader_tasks, timeout=1)
for task in pending:
task.cancel()
results = await asyncio.gather(*done, *pending, return_exceptions=True)
for result in results:
if isinstance(result, Exception) and not isinstance(
result, asyncio.CancelledError
):
logger.debug("命令输出读取任务异常: %s", result)
@staticmethod
def _format_result(
*,
exit_code: Optional[int],
output: _CommandOutput,
timeout: int,
timed_out: bool,
output_limited: bool,
timeout_note: Optional[str],
) -> str:
if timed_out:
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
elif output_limited:
result = (
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
f"已截断并终止进程,退出码: {exit_code})"
)
else:
result = f"命令执行完成 (退出码: {exit_code})"
if timeout_note:
result += f"\n\n提示:\n{timeout_note}"
if output.stdout:
result += f"\n\n标准输出:\n{output.stdout}"
if output.stderr:
result += f"\n\n错误输出:\n{output.stderr}"
if output.truncated:
result += "\n\n...(输出内容过长,已截断)"
if not output.stdout and not output.stderr:
result += "\n\n(无输出内容)"
return result
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
@@ -50,46 +248,57 @@ class ExecuteCommandTool(MoviePilotTool):
if keyword in command:
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
try:
# 执行命令
process = await asyncio.create_subprocess_shell(
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
try:
# 等待完成,带超时
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
try:
async with _command_semaphore:
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
process = await asyncio.create_subprocess_shell(
command, **self._subprocess_kwargs()
)
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
limit_reached = asyncio.Event()
wait_task = asyncio.create_task(process.wait())
limit_task = asyncio.create_task(limit_reached.wait())
reader_tasks = [
asyncio.create_task(
self._read_stream(
process.stdout, "stdout", output, limit_reached
)
),
asyncio.create_task(
self._read_stream(
process.stderr, "stderr", output, limit_reached
)
),
]
timed_out = False
output_limited = False
done, _ = await asyncio.wait(
{wait_task, limit_task},
timeout=normalized_timeout,
return_when=asyncio.FIRST_COMPLETED,
)
# 处理输出
stdout_str = stdout.decode("utf-8", errors="replace").strip()
stderr_str = stderr.decode("utf-8", errors="replace").strip()
exit_code = process.returncode
if wait_task not in done:
if limit_task in done:
output_limited = True
else:
timed_out = True
await self._cleanup_process(process, wait_task)
result = f"命令执行完成 (退出码: {exit_code})"
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
limit_task.cancel()
await self._finish_reader_tasks(reader_tasks)
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:
# 超时处理
try:
process.kill()
except ProcessLookupError:
pass
return f"命令执行超时 (限制: {timeout}秒)"
return self._format_result(
exit_code=process.returncode,
output=output,
timeout=normalized_timeout,
timed_out=timed_out,
output_limited=output_limited,
timeout_note=timeout_note,
)
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)

View File

@@ -62,7 +62,7 @@ class GetRecommendationsTool(MoviePilotTool):
"douban_hot": "豆瓣热门",
"douban_movie_hot": "豆瓣热门电影",
"douban_tv_hot": "豆瓣热门电视剧",
"douban_movie_showing": "豆瓣正在热映",
"douban_movie_showing": "豆瓣热映",
"douban_movies": "豆瓣最新电影",
"douban_tvs": "豆瓣最新电视剧",
"douban_movie_top250": "豆瓣电影TOP250",
@@ -73,7 +73,7 @@ class GetRecommendationsTool(MoviePilotTool):
}
source_desc = source_map.get(source, source)
message = f"正在获取推荐: {source_desc}"
message = f"获取推荐: {source_desc}"
if media_type != "all":
message += f" [{media_type}]"
message += f" (第{page}页)"

View File

@@ -53,7 +53,7 @@ class GetSearchResultsTool(MoviePilotTool):
args_schema: Type[BaseModel] = GetSearchResultsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
return "正在获取搜索结果"
return "获取搜索结果"
async def run(
self,

View File

@@ -32,7 +32,7 @@ class ListDirectoryTool(MoviePilotTool):
path = kwargs.get("path", "")
storage = kwargs.get("storage", "local")
message = f"正在查询目录: {path}"
message = f"查询目录: {path}"
if storage != "local":
message += f" [存储: {storage}]"

View File

@@ -33,7 +33,7 @@ class ListSlashCommandsTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
return "正在查询所有可用命令"
return "查询所有可用命令"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -55,7 +55,7 @@ class ModifyDownloadTool(MoviePilotTool):
tags = kwargs.get("tags")
downloader = kwargs.get("downloader")
parts = [f"正在修改下载任务: {hash_value}"]
parts = [f"修改下载任务: {hash_value}"]
if action == "start":
parts.append("操作: 开始下载")
elif action == "stop":

View File

@@ -31,7 +31,7 @@ class QueryCustomIdentifiersTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
return "正在查询自定义识别词"
return "查询自定义识别词"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -32,7 +32,7 @@ class QueryDirectorySettingsTool(MoviePilotTool):
storage_type = kwargs.get("storage_type", "all")
name = kwargs.get("name")
parts = ["正在查询目录配置"]
parts = ["查询目录配置"]
if directory_type != "all":
type_map = {"download": "下载目录", "library": "媒体库目录"}

View File

@@ -36,7 +36,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
查询所有状态的任务(包括下载中和已完成的任务)
"""
all_torrents = []
# 查询正在下载的任务
# 查询下载的任务
downloading_torrents = download_chain.list_torrents(
downloader=downloader,
status=TorrentStatus.DOWNLOADING
@@ -71,7 +71,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
hash_value = kwargs.get("hash")
title = kwargs.get("title")
parts = ["正在查询下载任务"]
parts = ["查询下载任务"]
if downloader:
parts.append(f"下载器: {downloader}")

View File

@@ -23,7 +23,7 @@ class QueryDownloadersTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
return "正在查询下载器配置"
return "查询下载器配置"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -29,7 +29,7 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
season = kwargs.get("season")
episode_group = kwargs.get("episode_group")
message = f"正在查询剧集上映时间: TMDB ID {tmdb_id}{season}"
message = f"查询剧集上映时间: TMDB ID {tmdb_id}{season}"
if episode_group:
message += f" (剧集组: {episode_group})"

View File

@@ -31,7 +31,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
return "正在查询已安装插件"
return "查询已安装插件"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -93,11 +93,11 @@ class QueryLibraryExistsTool(MoviePilotTool):
media_type = kwargs.get("media_type")
if tmdb_id:
message = f"正在查询媒体库: TMDB={tmdb_id}"
message = f"查询媒体库: TMDB={tmdb_id}"
elif douban_id:
message = f"正在查询媒体库: 豆瓣={douban_id}"
message = f"查询媒体库: 豆瓣={douban_id}"
else:
message = "正在查询媒体库"
message = "查询媒体库"
if media_type:
message += f" [{media_type}]"
return message

View File

@@ -39,7 +39,7 @@ class QueryLibraryLatestTool(MoviePilotTool):
server = kwargs.get("server")
page = kwargs.get("page", 1)
parts = ["正在查询媒体服务器最近入库影片"]
parts = ["查询媒体服务器最近入库影片"]
if server:
parts.append(f"服务器: {server}")

View File

@@ -29,8 +29,8 @@ class QueryMediaDetailTool(MoviePilotTool):
tmdb_id = kwargs.get("tmdb_id")
douban_id = kwargs.get("douban_id")
if tmdb_id:
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
return f"正在查询媒体详情: 豆瓣 ID {douban_id}"
return f"查询媒体详情: TMDB ID {tmdb_id}"
return f"查询媒体详情: 豆瓣 ID {douban_id}"
async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")

View File

@@ -40,8 +40,8 @@ class QueryPluginCapabilitiesTool(MoviePilotTool):
"""生成友好的提示消息"""
plugin_id = kwargs.get("plugin_id")
if plugin_id:
return f"正在查询插件 {plugin_id} 的能力"
return "正在查询所有插件的能力"
return f"查询插件 {plugin_id} 的能力"
return "查询所有插件的能力"
async def run(self, plugin_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")

View File

@@ -39,7 +39,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
min_rating = kwargs.get("min_rating")
max_rating = kwargs.get("max_rating")
parts = [f"正在查询热门订阅 [{media_type}]"]
parts = [f"查询热门订阅 [{media_type}]"]
if min_sub:
parts.append(f"最少订阅: {min_sub}")

View File

@@ -22,7 +22,7 @@ class QueryRuleGroupsTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
return "正在查询所有规则组"
return "查询所有规则组"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -22,7 +22,7 @@ class QuerySchedulersTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
return "正在查询定时服务"
return "查询定时服务"
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")

View File

@@ -40,7 +40,7 @@ class QuerySiteUserdataTool(MoviePilotTool):
site_id = kwargs.get("site_id")
workdate = kwargs.get("workdate")
message = f"正在查询站点 #{site_id} 的用户数据"
message = f"查询站点 #{site_id} 的用户数据"
if workdate:
message += f" (日期: {workdate})"
else:

View File

@@ -37,7 +37,7 @@ class QuerySitesTool(MoviePilotTool):
status = kwargs.get("status", "all")
name = kwargs.get("name")
parts = ["正在查询站点"]
parts = ["查询站点"]
if status != "all":
status_map = {"active": "已启用", "inactive": "已禁用"}

View File

@@ -44,7 +44,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
name = kwargs.get("name")
page = kwargs.get("page", 1)
parts = ["正在查询订阅历史"]
parts = ["查询订阅历史"]
if media_type != "all":
parts.append(f"类型: {media_type}")

View File

@@ -34,7 +34,7 @@ class QuerySubscribeSharesTool(MoviePilotTool):
min_rating = kwargs.get("min_rating")
max_rating = kwargs.get("max_rating")
parts = ["正在查询订阅分享"]
parts = ["查询订阅分享"]
if name:
parts.append(f"名称: {name}")

View File

@@ -79,7 +79,7 @@ class QuerySubscribesTool(MoviePilotTool):
media_type = kwargs.get("media_type", "all")
page = kwargs.get("page", 1)
parts = ["正在查询订阅"]
parts = ["查询订阅"]
# 根据状态过滤条件生成提示
if status != "all":

View File

@@ -33,7 +33,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
status = kwargs.get("status", "all")
page = kwargs.get("page", 1)
parts = ["正在查询整理历史"]
parts = ["查询整理历史"]
if title:
parts.append(f"标题: {title}")

View File

@@ -30,7 +30,7 @@ class QueryWorkflowsTool(MoviePilotTool):
name = kwargs.get("name")
trigger_type = kwargs.get("trigger_type", "all")
parts = ["正在查询工作流"]
parts = ["查询工作流"]
if state != "all":
state_map = {"W": "等待", "R": "运行中", "P": "暂停", "S": "成功", "F": "失败"}

View File

@@ -29,7 +29,7 @@ class ReadFileTool(MoviePilotTool):
"""根据参数生成友好的提示消息"""
file_path = kwargs.get("file_path", "")
file_name = Path(file_path).name if file_path else "未知文件"
return f"正在读取文件: {file_name}"
return f"读取文件: {file_name}"
async def run(self, file_path: str, start_line: Optional[int] = None,
end_line: Optional[int] = None, **kwargs) -> str:

View File

@@ -33,13 +33,13 @@ class RecognizeMediaTool(MoviePilotTool):
path = kwargs.get("path")
if path:
message = f"正在识别文件媒体信息: {path}"
message = f"识别文件媒体信息: {path}"
elif title:
message = f"正在识别种子媒体信息: {title}"
message = f"识别种子媒体信息: {title}"
if subtitle:
message += f" ({subtitle})"
else:
message = "正在识别媒体信息"
message = "识别媒体信息"
return message

View File

@@ -31,7 +31,7 @@ class RunSchedulerTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据运行参数生成友好的提示消息"""
job_id = kwargs.get("job_id", "")
return f"正在运行定时服务 (ID: {job_id})"
return f"运行定时服务 (ID: {job_id})"
async def run(self, job_id: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: job_id={job_id}")

View File

@@ -45,7 +45,7 @@ class RunSlashCommandTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
command = kwargs.get("command", "")
return f"正在执行命令: {command}"
return f"执行命令: {command}"
async def run(self, command: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: command={command}")

View File

@@ -38,7 +38,7 @@ class RunWorkflowTool(MoviePilotTool):
workflow_id = kwargs.get("workflow_id")
from_begin = kwargs.get("from_begin", True)
message = f"正在执行工作流: {workflow_id}"
message = f"执行工作流: {workflow_id}"
if not from_begin:
message += " (从上次位置继续)"
else:

View File

@@ -47,7 +47,7 @@ class ScrapeMetadataTool(MoviePilotTool):
storage = kwargs.get("storage", "local")
overwrite = kwargs.get("overwrite", False)
message = f"正在刮削媒体元数据: {path}"
message = f"刮削媒体元数据: {path}"
if storage != "local":
message += f" [存储: {storage}]"
if overwrite:

View File

@@ -34,7 +34,7 @@ class SearchMediaTool(MoviePilotTool):
media_type = kwargs.get("media_type")
season = kwargs.get("season")
message = f"正在搜索媒体: {title}"
message = f"搜索媒体: {title}"
if year:
message += f" ({year})"
if media_type:

View File

@@ -24,7 +24,7 @@ class SearchPersonTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据搜索参数生成友好的提示消息"""
name = kwargs.get("name", "")
return f"正在搜索人物: {name}"
return f"搜索人物: {name}"
async def run(self, name: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: name={name}")

View File

@@ -29,7 +29,7 @@ class SearchPersonCreditsTool(MoviePilotTool):
"""根据搜索参数生成友好的提示消息"""
person_id = kwargs.get("person_id", "")
source = kwargs.get("source", "")
return f"正在搜索人物参演作品: {source} ID {person_id}"
return f"搜索人物参演作品: {source} ID {person_id}"
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")

View File

@@ -32,7 +32,7 @@ class SearchSubscribeTool(MoviePilotTool):
subscribe_id = kwargs.get("subscribe_id")
manual = kwargs.get("manual", False)
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
message = f"搜索订阅 #{subscribe_id} 的缺失剧集"
if manual:
message += "(手动搜索)"

View File

@@ -41,11 +41,11 @@ class SearchTorrentsTool(MoviePilotTool):
media_type = kwargs.get("media_type")
if tmdb_id:
message = f"正在搜索种子: TMDB={tmdb_id}"
message = f"搜索种子: TMDB={tmdb_id}"
elif douban_id:
message = f"正在搜索种子: 豆瓣={douban_id}"
message = f"搜索种子: 豆瓣={douban_id}"
else:
message = "正在搜索种子"
message = "搜索种子"
if media_type:
message += f" [{media_type}]"
return message

View File

@@ -41,7 +41,7 @@ class SearchWebTool(MoviePilotTool):
"""根据搜索参数生成友好的提示消息"""
query = kwargs.get("query", "")
max_results = kwargs.get("max_results", 20)
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
return f"搜索网络内容: {query} (最多返回 {max_results} 条结果)"
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
"""

View File

@@ -55,7 +55,7 @@ class SendLocalFileTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
file_path = kwargs.get("file_path", "")
file_name = Path(file_path).name if file_path else "未知文件"
return f"正在发送本地附件: {file_name}"
return f"发送本地附件: {file_name}"
async def run(
self,

View File

@@ -52,12 +52,12 @@ class SendMessageTool(MoviePilotTool):
message = message[:50] + "..."
if title and image_url:
return f"正在发送图文消息: [{title}] {message}"
return f"发送图文消息: [{title}] {message}"
if title:
return f"正在发送消息: [{title}] {message}"
return f"发送消息: [{title}] {message}"
if image_url:
return f"正在发送图片消息: {message}"
return f"正在发送消息: {message}"
return f"发送图片消息: {message}"
return f"发送消息: {message}"
async def run(
self,

View File

@@ -41,7 +41,7 @@ class SendVoiceMessageTool(MoviePilotTool):
message = kwargs.get("message") or ""
if len(message) > 40:
message = message[:40] + "..."
return f"正在发送语音回复: {message}"
return f"发送语音回复: {message}"
def _supports_real_voice_reply(self) -> bool:
channel = self._channel or ""

View File

@@ -24,7 +24,7 @@ class TestSiteTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据测试参数生成友好的提示消息"""
site_identifier = kwargs.get("site_identifier")
return f"正在测试站点连通性: {site_identifier}"
return f"测试站点连通性: {site_identifier}"
async def run(self, site_identifier: int, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")

View File

@@ -68,7 +68,7 @@ class TransferFileTool(MoviePilotTool):
transfer_type = kwargs.get("transfer_type")
background = kwargs.get("background", False)
message = f"正在整理文件: {file_path}"
message = f"整理文件: {file_path}"
if media_type:
message += f" [{media_type}]"
if transfer_type:

View File

@@ -57,7 +57,7 @@ class UpdateCustomIdentifiersTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
identifiers = kwargs.get("identifiers", [])
return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"
return f"更新自定义识别词(共 {len(identifiers)} 条规则)"
async def run(self, identifiers: List[str] = None, **kwargs) -> str:
logger.info(

View File

@@ -95,8 +95,8 @@ class UpdateSiteTool(MoviePilotTool):
fields_updated.append("下载器")
if fields_updated:
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
return f"正在更新站点 #{site_id}"
return f"更新站点 #{site_id}: {', '.join(fields_updated)}"
return f"更新站点 #{site_id}"
async def run(
self,

View File

@@ -41,7 +41,7 @@ class UpdateSiteCookieTool(MoviePilotTool):
username = kwargs.get("username", "")
two_step_code = kwargs.get("two_step_code")
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
message = f"更新站点Cookie: {site_identifier} (用户: {username})"
if two_step_code:
message += " [需要两步验证]"

View File

@@ -117,8 +117,8 @@ class UpdateSubscribeTool(MoviePilotTool):
fields_updated.append("下载器")
if fields_updated:
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
return f"正在更新订阅 #{subscribe_id}"
return f"更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
return f"更新订阅 #{subscribe_id}"
async def run(
self,

View File

@@ -27,7 +27,7 @@ class WriteFileTool(MoviePilotTool):
"""根据参数生成友好的提示消息"""
file_path = kwargs.get("file_path", "")
file_name = Path(file_path).name if file_path else "未知文件"
return f"正在写入文件: {file_name}"
return f"写入文件: {file_name}"
async def run(self, file_path: str, content: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"])
api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"])

View File

@@ -0,0 +1,158 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional
from fastapi import APIRouter, Header, Security
from fastapi.responses import JSONResponse, StreamingResponse
from app import schemas
from app.api.endpoints.openai import (
MODEL_ID,
_CollectingMoviePilotAgent,
_error_response as _openai_error_response,
)
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
from app.core.config import settings
from app.core.security import anthropic_api_key_header
from app.schemas.types import MessageChannel
router = APIRouter()
SESSION_PREFIX = "anthropic:"
def _anthropic_error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
) -> JSONResponse:
    """Build an Anthropic-style error envelope as a JSON response."""
    detail = schemas.AnthropicErrorDetail(type=error_type, message=message)
    envelope = schemas.AnthropicErrorResponse(error=detail)
    return JSONResponse(status_code=status_code, content=envelope.model_dump())
def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]:
    """Return None when *api_key* matches the configured token, else a 401 response."""
    if api_key and api_key == settings.API_TOKEN:
        return None
    return _anthropic_error_response(
        "invalid x-api-key",
        401,
        error_type="authentication_error",
    )
async def _stream_anthropic_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    """
    Run the agent and yield Anthropic Messages-API SSE events.

    Event sequence: message_start, content_block_start, one
    content_block_delta per streamed token, then content_block_stop /
    message_delta / message_stop. Token usage is reported as zero because
    no counts are available at this layer.
    """
    # Tokens emitted by the agent's stream handler are forwarded through this queue.
    event_queue: asyncio.Queue = asyncio.Queue()
    if hasattr(agent.stream_handler, "bind_queue"):
        agent.stream_handler.bind_queue(event_queue)
    message_id = f"msg_{uuid.uuid4().hex}"

    async def _run_agent():
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            # Surface agent failures to the consumer loop below.
            await event_queue.put({"error": str(exc)})
        finally:
            # None is the end-of-stream sentinel.
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())
    try:
        yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n"
        while True:
            item = await event_queue.get()
            if item is None:
                break
            if isinstance(item, dict) and item.get("error"):
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n"
        yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n"
        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n"
    finally:
        # Client disconnects abort the generator; make sure the agent task dies too.
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
async def messages(
payload: schemas.AnthropicMessagesRequest,
x_api_key: Optional[str] = Security(anthropic_api_key_header),
anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"),
):
auth_error = _check_auth(x_api_key)
if auth_error:
return auth_error
if not settings.AI_AGENT_ENABLE:
return _anthropic_error_response(
"MoviePilot AI agent is disabled.",
503,
error_type="api_error",
)
normalized_messages = build_anthropic_messages(payload.system, payload.messages)
try:
prompt, images = build_prompt(normalized_messages, use_server_session=False)
except ValueError as exc:
return _anthropic_error_response(str(exc), 400)
session_seed = anthropic_version or "anthropic"
session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX)
agent = _CollectingMoviePilotAgent(
session_id=session_id,
user_id=session_id,
channel=MessageChannel.Web.value,
source="anthropic",
username="anthropic-client",
stream_mode=payload.stream,
)
if payload.stream:
return StreamingResponse(
_stream_anthropic_response(agent=agent, prompt=prompt, images=images),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
try:
result = await agent.process(prompt, images=images, files=None)
except Exception as exc:
return _anthropic_error_response(str(exc), 500, error_type="api_error")
content = "\n\n".join(
message.strip()
for message in agent.collected_messages
if message and message.strip()
).strip()
if not content and result:
content = str(result).strip()
if not content:
content = "未获得有效回复。"
return schemas.AnthropicMessagesResponse(
id=f"msg_{uuid.uuid4().hex}",
content=[schemas.AnthropicTextBlock(text=content)],
model=MODEL_ID,
)

426
app/api/endpoints/openai.py Normal file
View File

@@ -0,0 +1,426 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import APIRouter, Request, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials
from app import schemas
from app.api.openai_utils import (
build_completion_payload,
build_prompt,
build_responses_input,
build_session_id,
)
from app.agent import MoviePilotAgent, StreamingHandler
from app.core.config import settings
from app.core.security import openai_bearer_scheme
from app.schemas.types import MessageChannel
router = APIRouter()
MODEL_ID = "moviepilot-agent"
SESSION_PREFIX = "openai:"
class _CollectingMoviePilotAgent(MoviePilotAgent):
    """
    Agent subclass that captures final output locally instead of re-sending
    it through the regular message channels.
    """

    def __init__(self, *args, stream_mode: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated assistant texts, later joined into the HTTP response body.
        self.collected_messages: List[str] = []
        self.stream_mode = stream_mode
        if stream_mode:
            # Swap in a handler that mirrors tokens to an SSE queue.
            self.stream_handler = _OpenAIStreamingHandler()

    def _should_stream(self) -> bool:
        """Stream only when the API caller requested SSE output."""
        return self.stream_mode

    async def send_agent_message(self, message: str, title: str = ""):
        """Collect (and optionally stream) an outgoing message instead of dispatching it."""
        text = (message or "").strip()
        if title and text:
            text = f"{title}\n{text}"
        elif title:
            text = title.strip()
        if text:
            self.collected_messages.append(text)
            if self.stream_mode:
                self.stream_handler.emit(text)

    async def _save_agent_message_to_db(self, message: str, title: str = ""):
        """Skip DB persistence for API-driven sessions."""
        return None
class _OpenAIStreamingHandler(StreamingHandler):
    """
    Forward the agent's streaming output to an OpenAI SSE queue without
    posting anything to the in-app message system.
    """

    def __init__(self):
        super().__init__()
        # Queue bound by the SSE generator; None until a request attaches one.
        self._event_queue: Optional[asyncio.Queue] = None

    def bind_queue(self, queue: asyncio.Queue):
        """Attach the asyncio queue that the SSE generator drains."""
        self._event_queue = queue

    def emit(self, token: str):
        """Buffer the token via the base handler and mirror it to the SSE queue."""
        super().emit(token)
        if token and self._event_queue is not None:
            self._event_queue.put_nowait(token)

    async def start_streaming(
        self,
        channel: Optional[str] = None,
        source: Optional[str] = None,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
        title: str = "",
    ):
        """
        Reset per-request streaming state without touching channel messages.

        NOTE(review): assigns base-class bookkeeping attributes directly
        (_sent_text, _message_response, _msg_start_offset, ...) — assumed to
        mirror StreamingHandler's contract; confirm against the base class.
        """
        self._channel = channel
        self._source = source
        self._user_id = user_id
        self._username = username
        self._title = title
        self._streaming_enabled = True
        self._sent_text = ""
        self._message_response = None
        self._msg_start_offset = 0
        self._max_message_length = 0

    async def stop_streaming(self) -> Tuple[bool, str]:
        """Disable streaming and return (was_streaming, buffered_text)."""
        if not self._streaming_enabled:
            return False, ""
        self._streaming_enabled = False
        # _lock/_buffer are inherited from StreamingHandler (not visible here).
        with self._lock:
            final_text = self._buffer
            self._buffer = ""
        self._sent_text = ""
        self._message_response = None
        self._msg_start_offset = 0
        return True, final_text
def _sse_payload(data: dict) -> str:
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _stream_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    """
    Run the agent and yield OpenAI chat.completion.chunk SSE frames.

    The first chunk announces the assistant role, followed by one chunk per
    token from the agent's stream handler, then a finish chunk and the
    OpenAI `[DONE]` terminator.
    """
    # Tokens emitted by the agent's stream handler are forwarded through this queue.
    event_queue: asyncio.Queue = asyncio.Queue()
    if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
        agent.stream_handler.bind_queue(event_queue)
    created = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    finished = False

    async def _run_agent():
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            # Hand the failure to the consumer loop below.
            await event_queue.put({"error": str(exc)})
        finally:
            # None is the end-of-stream sentinel.
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())
    try:
        # Opening chunk: announces the assistant role.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant"},
                        "finish_reason": None,
                    }
                ],
            }
        )
        while True:
            item = await event_queue.get()
            if item is None:
                break
            if isinstance(item, dict) and item.get("error"):
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield _sse_payload(
                {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "finish_reason": None,
                        }
                    ],
                }
            )
        finished = True
        # Closing chunk plus the [DONE] terminator required by OpenAI clients.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop",
                    }
                ],
            }
        )
        yield "data: [DONE]\n\n"
    finally:
        if not task.done():
            # Generator aborted early (client disconnect) — stop the agent task.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        elif finished:
            # Task already completed normally; awaiting here only re-raises an
            # unexpected BaseException from _run_agent. NOTE(review): likely a
            # no-op in practice — confirm intent.
            await task
def _error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
    code: Optional[str] = None,
) -> JSONResponse:
    """Build an OpenAI-style error JSON response (with a WWW-Authenticate header)."""
    detail = schemas.OpenAIErrorDetail(message=message, type=error_type, code=code)
    body = schemas.OpenAIErrorResponse(error=detail).model_dump()
    return JSONResponse(
        status_code=status_code,
        content=body,
        headers={"WWW-Authenticate": "Bearer"},
    )
def _check_auth(
    credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[JSONResponse]:
    """
    Validate Bearer credentials against the configured API token.

    Returns None when authentication succeeds, otherwise a 401 error
    response in OpenAI format.
    """
    import hmac  # stdlib; local import keeps the module import block untouched

    ok = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        # Constant-time comparison avoids leaking the token through timing.
        and hmac.compare_digest(
            str(credentials.credentials or ""), str(settings.API_TOKEN or "")
        )
    )
    if ok:
        return None
    return _error_response(
        "Invalid bearer token.",
        401,
        error_type="authentication_error",
        code="invalid_api_key",
    )
@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
async def list_models(
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
auth_error = _check_auth(credentials)
if auth_error:
return auth_error
now = int(time.time())
return schemas.OpenAIModelListResponse(
data=[schemas.OpenAIModelInfo(id=MODEL_ID, created=now)]
)
@router.post(
    "/chat/completions",
    summary="OpenAI compatible chat completions",
    response_model=schemas.OpenAIChatCompletionResponse,
)
async def chat_completions(
    payload: schemas.OpenAIChatCompletionsRequest,
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    OpenAI-compatible /chat/completions endpoint backed by the MoviePilot agent.

    Auth, feature-flag and payload validation happen first; the session key
    comes from `user` or the `x-session-id` header (falling back to a random
    one-off id), and the request is served either as SSE (stream=True) or as
    a single JSON completion.
    """
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error
    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )
    if not payload.messages:
        return _error_response(
            "`messages` must be a non-empty array.",
            400,
            code="invalid_messages",
        )
    # Compute the caller-supplied session key once (it was previously
    # evaluated twice: once for session_key, once for use_server_session).
    explicit_session_key = (
        str(payload.user or "").strip()
        or str(request.headers.get("x-session-id") or "").strip()
    )
    # Server-side session continuity only applies when the caller identified itself.
    use_server_session = bool(explicit_session_key)
    session_key = explicit_session_key or str(uuid.uuid4())
    try:
        prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_messages")
    session_id = build_session_id(session_key, SESSION_PREFIX)
    username = str(payload.user or "openai-client")
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai",
        username=username,
        stream_mode=payload.stream,
    )
    if payload.stream:
        return StreamingResponse(
            _stream_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable reverse-proxy buffering so SSE frames flush immediately.
                "X-Accel-Buffering": "no",
            },
        )
    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )
    # Prefer messages the agent explicitly sent; fall back to the raw result.
    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"
    return JSONResponse(content=build_completion_payload(content, MODEL_ID))
@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
async def responses(
payload: schemas.OpenAIResponsesRequest,
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
auth_error = _check_auth(credentials)
if auth_error:
return auth_error
if not settings.AI_AGENT_ENABLE:
return _error_response(
"MoviePilot AI agent is disabled.",
503,
error_type="server_error",
code="ai_agent_disabled",
)
if payload.stream:
return _error_response(
"Streaming is not supported for /responses yet.",
400,
code="unsupported_stream",
)
normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
if not normalized_messages:
return _error_response(
"`input` must include at least one usable message.",
400,
code="invalid_input",
)
try:
prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
except ValueError as exc:
return _error_response(str(exc), 400, code="invalid_input")
session_key = str(payload.user or uuid.uuid4())
session_id = build_session_id(session_key, SESSION_PREFIX)
agent = _CollectingMoviePilotAgent(
session_id=session_id,
user_id=session_key,
channel=MessageChannel.Web.value,
source="openai.responses",
username=str(payload.user or "openai-client"),
stream_mode=False,
)
try:
result = await agent.process(prompt, images=images, files=None)
except Exception as exc:
return _error_response(
str(exc),
500,
error_type="server_error",
code="agent_execution_failed",
)
content = "\n\n".join(
message.strip()
for message in agent.collected_messages
if message and message.strip()
).strip()
if not content and result:
content = str(result).strip()
if not content:
content = "未获得有效回复。"
created_at = int(time.time())
response_id = f"resp_{uuid.uuid4().hex}"
output_message = schemas.OpenAIResponsesOutputMessage(
id=f"msg_{uuid.uuid4().hex}",
content=[schemas.OpenAIResponsesOutputText(text=content)],
)
return schemas.OpenAIResponsesResponse(
id=response_id,
created_at=created_at,
model=MODEL_ID,
output=[output_message],
usage=schemas.OpenAIUsage(),
)

View File

@@ -12,6 +12,7 @@ from anyio import Path as AsyncPath
from app.helper.sites import SitesHelper # noqa # noqa
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app import schemas
from app.chain.mediaserver import MediaServerChain
@@ -29,14 +30,14 @@ from app.db.user_oper import (
get_current_active_superuser_async,
get_current_active_user_async,
)
from app.helper.llm import LLMHelper, LLMTestError, LLMTestTimeout
from app.helper.image import ImageHelper
from app.helper.llm import LLMHelper, LLMTestTimeout
from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper
from app.helper.image import ImageHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas import ConfigChangeEventData
@@ -45,7 +46,6 @@ from app.utils.crypto import HashUtils
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils
from pydantic import BaseModel
from version import APP_VERSION
router = APIRouter()
@@ -57,7 +57,7 @@ class LlmTestRequest(BaseModel):
enabled: Optional[bool] = None
provider: Optional[str] = None
model: Optional[str] = None
disable_thinking: Optional[bool] = None
thinking_level: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
@@ -269,74 +269,6 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
return rules
def _build_llm_test_data(
duration_ms: Optional[int] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
) -> dict[str, Any]:
"""
构造 LLM 测试接口的基础返回数据。
"""
data = {
"provider": provider if provider is not None else settings.LLM_PROVIDER,
"model": model if model is not None else settings.LLM_MODEL,
}
if duration_ms is not None:
data["duration_ms"] = duration_ms
return data
def _normalize_llm_test_value(
value: Optional[str], *, empty_as_none: bool = False
) -> Optional[str]:
"""
清理来自前端的 LLM 测试字段。
"""
if value is None:
return None
stripped = value.strip()
if empty_as_none and not stripped:
return None
return stripped
def _build_llm_test_snapshot(payload: Optional[LlmTestRequest] = None) -> dict[str, Any]:
"""
冻结当前 LLM 测试所需配置。
优先使用前端传入的临时参数;未传入时回退到已保存配置,兼容旧调用。
"""
provider = settings.LLM_PROVIDER
model = settings.LLM_MODEL
disable_thinking = bool(getattr(settings, "LLM_DISABLE_THINKING", False))
api_key = settings.LLM_API_KEY
base_url = settings.LLM_BASE_URL
enabled = bool(settings.AI_AGENT_ENABLE)
if payload:
if payload.enabled is not None:
enabled = bool(payload.enabled)
if payload.provider is not None:
provider = _normalize_llm_test_value(payload.provider) or ""
if payload.model is not None:
model = _normalize_llm_test_value(payload.model) or ""
if payload.disable_thinking is not None:
disable_thinking = bool(payload.disable_thinking)
if payload.api_key is not None:
api_key = _normalize_llm_test_value(payload.api_key, empty_as_none=True)
if payload.base_url is not None:
base_url = _normalize_llm_test_value(payload.base_url, empty_as_none=True)
return {
"enabled": enabled,
"provider": provider,
"model": model,
"disable_thinking": disable_thinking,
"api_key": api_key,
"base_url": base_url,
}
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
"""
清理错误信息中的敏感字段,避免回显密钥。
@@ -428,12 +360,12 @@ async def _close_nettest_response(response: Any) -> None:
async def fetch_image(
url: str,
proxy: Optional[bool] = None,
use_cache: bool = False,
if_none_match: Optional[str] = None,
cookies: Optional[str | dict] = None,
allowed_domains: Optional[set[str]] = None,
url: str,
proxy: Optional[bool] = None,
use_cache: bool = False,
if_none_match: Optional[str] = None,
cookies: Optional[str | dict] = None,
allowed_domains: Optional[set[str]] = None,
) -> Optional[Response]:
"""
处理图片缓存逻辑支持HTTP缓存和磁盘缓存
@@ -455,6 +387,7 @@ async def fetch_image(
use_cache=use_cache,
cookies=cookies,
)
if content:
# 检查 If-None-Match
etag = HashUtils.md5(content)
@@ -467,16 +400,17 @@ async def fetch_image(
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers,
)
return None
@router.get("/img/{proxy}", summary="图片代理")
async def proxy_img(
imgurl: str,
proxy: bool = False,
cache: bool = False,
use_cookies: bool = False,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token),
imgurl: str,
proxy: bool = False,
cache: bool = False,
use_cookies: bool = False,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token),
) -> Response:
"""
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
@@ -505,9 +439,9 @@ async def proxy_img(
@router.get("/cache/image", summary="图片缓存")
async def cache_img(
url: str,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token),
url: str,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token),
) -> Response:
"""
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
@@ -601,7 +535,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
async def set_env_setting(
env: dict, _: User = Depends(get_current_active_superuser_async)
env: dict, _: User = Depends(get_current_active_superuser_async)
):
"""
更新系统环境变量(仅管理员)
@@ -636,9 +570,9 @@ async def set_env_setting(
@router.get("/progress/{process_type}", summary="实时进度")
async def get_progress(
request: Request,
process_type: str,
_: schemas.TokenPayload = Depends(verify_resource_token),
request: Request,
process_type: str,
_: schemas.TokenPayload = Depends(verify_resource_token),
):
"""
实时获取处理进度返回格式为SSE
@@ -673,9 +607,9 @@ async def get_setting(key: str, _: User = Depends(get_current_active_user_async)
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
async def set_setting(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser_async),
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser_async),
):
"""
更新系统设置(仅管理员)
@@ -709,10 +643,10 @@ async def set_setting(
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
async def get_llm_models(
provider: str,
api_key: str,
base_url: Optional[str] = None,
_: User = Depends(get_current_active_user_async),
provider: str,
api_key: str,
base_url: Optional[str] = None,
_: User = Depends(get_current_active_user_async),
):
"""
获取LLM模型列表
@@ -728,28 +662,33 @@ async def get_llm_models(
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
async def llm_test(
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
_: User = Depends(get_current_active_superuser_async),
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
_: User = Depends(get_current_active_superuser_async),
):
"""
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
"""
snapshot = _build_llm_test_snapshot(payload)
data = _build_llm_test_data(
provider=snapshot["provider"],
model=snapshot["model"],
)
if not snapshot["enabled"]:
if not payload:
return schemas.Response(success=False, message="请配置智能助手LLM相关参数后再进行测试")
if not payload.provider or not payload.model:
return schemas.Response(success=False, message="请配置LLM提供商和模型")
data = {
"provider": payload.provider,
"model": payload.model,
}
if not payload.enabled:
return schemas.Response(success=False, message="请先启用智能助手", data=data)
if not snapshot["api_key"]:
if not payload.api_key or not payload.api_key.strip():
return schemas.Response(
success=False,
message="请先配置 LLM API Key",
data=data,
)
if not (snapshot["model"] or "").strip():
if not payload.model or not payload.model.strip():
return schemas.Response(
success=False,
message="请先配置 LLM 模型",
@@ -758,50 +697,36 @@ async def llm_test(
try:
result = await LLMHelper.test_current_settings(
provider=snapshot["provider"],
model=snapshot["model"],
disable_thinking=snapshot["disable_thinking"],
api_key=snapshot["api_key"],
base_url=snapshot["base_url"],
provider=payload.provider,
model=payload.model,
thinking_level=payload.thinking_level,
api_key=payload.api_key,
base_url=payload.base_url,
)
if not result.get("reply_preview"):
return schemas.Response(
success=False,
message="模型响应为空",
data=_build_llm_test_data(
result.get("duration_ms"),
provider=snapshot["provider"],
model=snapshot["model"],
),
message="模型响应为空"
)
return schemas.Response(success=True, data=result)
except (LLMTestTimeout, TimeoutError) as err:
logger.warning(err)
return schemas.Response(
success=False,
message="LLM 调用超时",
data=_build_llm_test_data(
getattr(err, "duration_ms", None),
provider=snapshot["provider"],
model=snapshot["model"],
),
message="LLM 调用超时"
)
except Exception as err:
return schemas.Response(
success=False,
message=_sanitize_llm_test_error(str(err), snapshot["api_key"]),
data=_build_llm_test_data(
getattr(err, "duration_ms", None),
provider=snapshot["provider"],
model=snapshot["model"],
),
message=_sanitize_llm_test_error(str(err), payload.api_key)
)
@router.get("/message", summary="实时消息")
async def get_message(
request: Request,
role: Optional[str] = "system",
_: schemas.TokenPayload = Depends(verify_resource_token),
request: Request,
role: Optional[str] = "system",
_: schemas.TokenPayload = Depends(verify_resource_token),
):
"""
实时获取系统消息返回格式为SSE
@@ -824,10 +749,10 @@ async def get_message(
@router.get("/logging", summary="实时日志")
async def get_logging(
request: Request,
length: Optional[int] = 50,
logfile: Optional[str] = "moviepilot.log",
_: schemas.TokenPayload = Depends(verify_resource_token),
request: Request,
length: Optional[int] = 50,
logfile: Optional[str] = "moviepilot.log",
_: schemas.TokenPayload = Depends(verify_resource_token),
):
"""
实时获取系统日志
@@ -838,7 +763,7 @@ async def get_logging(
log_path = base_path / logfile
if not await SecurityUtils.async_is_safe_path(
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
):
raise HTTPException(status_code=404, detail="Not Found")
@@ -855,7 +780,7 @@ async def get_logging(
# 读取历史日志
async with aiofiles.open(
log_path, mode="r", encoding="utf-8", errors="ignore"
log_path, mode="r", encoding="utf-8", errors="ignore"
) as f:
# 优化大文件读取策略
if file_size > 100 * 1024:
@@ -867,7 +792,7 @@ async def get_logging(
# 找到第一个完整的行
first_newline = content.find("\n")
if first_newline != -1:
content = content[first_newline + 1 :]
content = content[first_newline + 1:]
else:
# 小文件直接读取全部内容
content = await f.read()
@@ -875,7 +800,7 @@ async def get_logging(
# 按行分割并添加到队列,只保留非空行
lines = [line.strip() for line in content.splitlines() if line.strip()]
# 只取最后N行
for line in lines[-max(length, 50) :]:
for line in lines[-max(length, 50):]:
lines_queue.append(line)
# 输出历史日志
@@ -884,7 +809,7 @@ async def get_logging(
# 实时监听新日志
async with aiofiles.open(
log_path, mode="r", encoding="utf-8", errors="ignore"
log_path, mode="r", encoding="utf-8", errors="ignore"
) as f:
# 移动文件指针到文件末尾,继续监听新增内容
await f.seek(0, 2)
@@ -923,7 +848,7 @@ async def get_logging(
try:
# 使用 aiofiles 异步读取文件
async with aiofiles.open(
log_path, mode="r", encoding="utf-8", errors="ignore"
log_path, mode="r", encoding="utf-8", errors="ignore"
) as file:
text = await file.read()
# 倒序输出
@@ -955,10 +880,10 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
def ruletest(
title: str,
rulegroup_name: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
title: str,
rulegroup_name: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
):
"""
过滤规则测试,规则类型 1-订阅2-洗版3-搜索
@@ -1013,11 +938,10 @@ async def nettest_targets(_: schemas.TokenPayload = Depends(verify_token)):
@router.get("/nettest", summary="测试网络连通性")
async def nettest(
target_id: Optional[str] = None,
url: Optional[str] = None,
proxy: Optional[bool] = None,
include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
target_id: Optional[str] = None,
url: Optional[str] = None,
include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
):
"""
测试内置目标的网络连通性。

177
app/api/openai_utils.py Normal file
View File

@@ -0,0 +1,177 @@
import hashlib
import time
import uuid
from typing import Any, Dict, List, Tuple
def _get_message_field(message: Any, field: str, default: Any = None) -> Any:
if isinstance(message, dict):
return message.get(field, default)
return getattr(message, field, default)
def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
    """
    Normalize an OpenAI/Anthropic message ``content`` value.

    Accepts a plain string or a list of content parts (text / input_text /
    image_url / input_image / base64 image blocks) and returns the joined
    text plus a list of image URLs or data URIs. Anything unrecognized is
    silently skipped.
    """
    if content is None:
        return "", []
    if isinstance(content, str):
        return content.strip(), []
    texts: List[str] = []
    images: List[str] = []
    parts = content if isinstance(content, list) else []
    for part in parts:
        if isinstance(part, str):
            cleaned = part.strip()
            if cleaned:
                texts.append(cleaned)
            continue
        if not isinstance(part, dict):
            continue
        kind = (part.get("type") or "").lower()
        if kind in ("text", "input_text"):
            raw = part.get("text")
            if raw and str(raw).strip():
                texts.append(str(raw).strip())
        elif kind == "image_url":
            # OpenAI chat format: image_url may be a dict {"url": ...} or a bare string.
            holder = part.get("image_url")
            url = holder.get("url") if isinstance(holder, dict) else holder
            if url and str(url).strip():
                images.append(str(url).strip())
        elif kind == "input_image":
            # OpenAI responses format: image_url is a bare string.
            url = part.get("image_url")
            if url and str(url).strip():
                images.append(str(url).strip())
        elif kind == "image":
            # Anthropic format: inline base64 source converted into a data URI.
            source = part.get("source") or {}
            if isinstance(source, dict) and source.get("type") == "base64":
                data = source.get("data")
                media_type = source.get("media_type") or "image/png"
                if data and str(data).strip():
                    images.append(f"data:{media_type};base64,{str(data).strip()}")
    return "\n".join(texts).strip(), images
def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]:
    """
    Collapse a chat-style message list into a single prompt string.

    :param messages: sequence of dicts or objects exposing ``role``/``content``
    :param use_server_session: when True, earlier turns are assumed to live
        server-side, so the locally rebuilt transcript is omitted
    :return: tuple of (prompt text, image URLs from the latest user message)
    :raises ValueError: when no user message with text or images is present
    """
    system_texts: List[str] = []
    transcript: List[str] = []
    latest_user_text = ""
    latest_user_images: List[str] = []
    for message in messages:
        role = str(_get_message_field(message, "role", "user") or "user").lower()
        # OpenAI treats "developer" as the modern alias for "system".
        if role == "developer":
            role = "system"
        text, images = extract_text_and_images(_get_message_field(message, "content"))
        if role == "system":
            if text:
                system_texts.append(text)
            continue
        if role == "user":
            # Keep overwriting so the most recent user turn wins.
            if text or images:
                latest_user_text = text
                latest_user_images = images
            if text:
                transcript.append(f"user: {text}")
            continue
        if text:
            transcript.append(f"{role}: {text}")
    if not latest_user_text and not latest_user_images:
        raise ValueError("No usable user message found in messages.")
    prompt_parts: List[str] = []
    if system_texts:
        prompt_parts.append("系统要求:\n" + "\n\n".join(system_texts))
    if not use_server_session and transcript:
        # Drop the final user turn from the replayed history (it is repeated
        # below as the current message) and cap context at 10 recent lines.
        history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript
        if history:
            prompt_parts.append("对话上下文:\n" + "\n".join(history[-10:]))
    if latest_user_text:
        prompt_parts.append("当前用户消息:\n" + latest_user_text)
    else:
        # Image-only turn: ask the model to answer based on the pictures.
        prompt_parts.append("当前用户消息:\n请结合图片内容回复。")
    return "\n\n".join(part for part in prompt_parts if part).strip(), latest_user_images
def build_session_id(session_key: str, prefix: str) -> str:
    """Derive a stable, prefixed session id from *session_key* (SHA-256 based)."""
    fingerprint = hashlib.sha256(session_key.encode("utf-8")).hexdigest()[:32]
    return prefix + fingerprint
def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]:
    """
    Wrap assistant *content* in an OpenAI chat-completion response envelope.

    Token usage is reported as zero because the backing agent does not
    expose token accounting for this endpoint.
    """
    single_choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": content,
        },
        "finish_reason": "stop",
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_id,
        "choices": [single_choice],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
def build_responses_input(
input_data: Any, instructions: str | None = None
) -> List[Dict[str, Any]]:
messages: List[Dict[str, Any]] = []
if instructions and str(instructions).strip():
messages.append({"role": "system", "content": str(instructions).strip()})
if isinstance(input_data, str):
normalized = input_data.strip()
if normalized:
messages.append({"role": "user", "content": normalized})
return messages
if isinstance(input_data, list):
for item in input_data:
if not isinstance(item, dict):
continue
item_type = (item.get("type") or "").lower()
if item_type == "message":
role = item.get("role") or "user"
content = item.get("content")
messages.append({"role": role, "content": content})
elif item.get("role") and "content" in item:
messages.append({"role": item.get("role"), "content": item.get("content")})
return messages
if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data:
messages.append({"role": input_data.get("role"), "content": input_data.get("content")})
return messages
def build_anthropic_messages(
    system: Any, messages: List[Any]
) -> List[Dict[str, Any]]:
    """
    Prepend the Anthropic ``system`` prompt (when non-empty) to *messages*,
    returning a flat list of role/content dicts.
    """
    result: List[Dict[str, Any]] = []
    system_text, _ = extract_text_and_images(system)
    if system_text:
        result.append({"role": "system", "content": system_text})
    result.extend(
        {
            "role": _get_message_field(message, "role", "user"),
            "content": _get_message_field(message, "content"),
        }
        for message in messages
    )
    return result

View File

@@ -1407,6 +1407,7 @@ class ChainBase(metaclass=ABCMeta):
chat_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> bool:
"""
编辑已发送的消息
@@ -1416,6 +1417,7 @@ class ChainBase(metaclass=ABCMeta):
:param chat_id: 聊天ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 更新后的按钮列表
:return: 编辑是否成功
"""
return self.run_module(
@@ -1426,6 +1428,7 @@ class ChainBase(metaclass=ABCMeta):
chat_id=chat_id,
text=text,
title=title,
buttons=buttons,
)
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:

1363
app/chain/interaction.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -566,8 +566,8 @@ class SearchChain(ChainBase):
) or []
)
search_count += 1
# 有结果则停止
if torrents:
# 未开启多名称搜索时,有结果则停止
if not settings.SEARCH_MULTIPLE_NAME and torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break
@@ -654,7 +654,7 @@ class SearchChain(ChainBase):
}
search_count += 1
if torrents:
if not settings.SEARCH_MULTIPLE_NAME and torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break

1241
app/chain/skills.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,7 @@ from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.message import MessageChain
from app.chain.site import SiteChain
from app.chain.skills import SkillsChain
from app.chain.subscribe import SubscribeChain
from app.chain.system import SystemChain
from app.chain.transfer import TransferChain
@@ -154,6 +155,18 @@ class Command(metaclass=Singleton):
"category": "管理",
"data": {},
},
"/session_status": {
"func": MessageChain().remote_session_status,
"description": "会话状态",
"category": "智能体",
"data": {},
},
"/skills": {
"func": SkillsChain().remote_manage,
"description": "管理技能",
"category": "智能体",
"data": {},
},
}
# 插件命令集合
self._plugin_commands = {}

View File

@@ -420,6 +420,15 @@ class ConfigModel(BaseModel):
# 本地插件仓库目录,多个地址使用,分隔
PLUGIN_LOCAL_REPO_PATHS: Optional[str] = None
# ==================== 技能配置 ====================
# 技能市场仓库地址,多个地址使用,分隔
SKILL_MARKET: str = (
"https://clawhub.ai,"
"https://github.com/openai/skills,"
"https://github.com/anthropics/skills,"
"https://github.com/vercel-labs/agent-skills"
)
# ==================== Github & PIP ====================
# Github token提高请求api限流阈值 ghp_****
GITHUB_TOKEN: Optional[str] = None
@@ -496,8 +505,8 @@ class ConfigModel(BaseModel):
LLM_PROVIDER: str = "deepseek"
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"
# 是否尽量关闭模型的思考/推理能力(按各 provider/model 支持情况自动适配)
LLM_DISABLE_THINKING: bool = True
# 思考模式/深度配置off/auto/minimal/low/medium/high/max/xhigh
LLM_THINKING_LEVEL: Optional[str] = 'off'
# LLM是否支持图片输入开启后消息图片会按多模态输入发送给模型
LLM_SUPPORT_IMAGE_INPUT: bool = True
# LLM API密钥

View File

@@ -13,7 +13,7 @@ from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from cryptography.fernet import Fernet
from fastapi import HTTPException, status, Security, Request, Response
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer
from passlib.context import CryptContext
from app import schemas
@@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
# API KEY 通过 QUERY 认证
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
# OpenAI compatible Bearer Token 认证
openai_bearer_scheme = HTTPBearer(auto_error=False)
# Anthropic compatible API Key 认证
anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header")
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None

View File

@@ -2,9 +2,13 @@
import asyncio
import inspect
import json
import time
from functools import wraps
from typing import Any, List
from langchain_core.messages import AIMessage
from app.core.config import settings
from app.log import logger
@@ -70,21 +74,120 @@ def _get_httpx_proxy_key() -> str:
if "proxy" in params:
return "proxy"
return "proxies"
except Exception:
except Exception as e:
logger.warning(f"检测 httpx 代理参数失败,默认使用 'proxies'{e}")
return "proxies"
def _deepseek_thinking_toggle(extra_body: Any) -> bool | None:
"""
解析 DeepSeek extra_body 中显式传入的 thinking 开关。
"""
if not isinstance(extra_body, dict):
return None
thinking = extra_body.get("thinking")
if not isinstance(thinking, dict):
return None
thinking_type = str(thinking.get("type") or "").strip().lower()
if thinking_type == "enabled":
return True
if thinking_type == "disabled":
return False
return None
def _is_deepseek_thinking_enabled(model_name: str | None, extra_body: Any) -> bool:
    """
    Decide whether this DeepSeek call runs in thinking mode.

    An explicit ``extra_body`` toggle always wins; otherwise the model name
    decides: deepseek-reasoner always thinks, and DeepSeek V4 models think
    by default unless explicitly disabled.
    """
    toggle = _deepseek_thinking_toggle(extra_body)
    if toggle is not None:
        return toggle
    name = str(model_name or "").strip().lower()
    return name == "deepseek-reasoner" or name.startswith("deepseek-v4-")
def _patch_deepseek_reasoning_content_support():
    """
    Patch langchain-deepseek's missing ``reasoning_content`` propagation in
    tool-call scenarios.

    DeepSeek thinking mode requires that, when an assistant history message
    contains ``tool_calls``, subsequent requests carry back that message's
    top-level ``reasoning_content``. Some langchain-deepseek versions read
    ``reasoning_content`` from responses but never write it back when the
    message history is replayed, which makes the API return HTTP 400.
    """
    try:
        from langchain_deepseek import ChatDeepSeek
    except Exception as err:
        logger.debug(f"跳过 langchain-deepseek reasoning_content 修补:{err}")
        return
    # Idempotence guard: only patch the class once per process.
    if getattr(ChatDeepSeek, "_moviepilot_reasoning_content_patched", False):
        return
    original_get_request_payload = getattr(ChatDeepSeek, "_get_request_payload", None)
    if not callable(original_get_request_payload):
        logger.warning("langchain-deepseek 缺少 _get_request_payload,无法修补 reasoning_content")
        return

    @wraps(original_get_request_payload)
    def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
        payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
        # Resolve original messages so we can extract reasoning_content from
        # additional_kwargs. The parent's payload builder does not propagate
        # this DeepSeek-specific field.
        messages = self._convert_input(input_).to_messages()
        for i, message in enumerate(payload["messages"]):
            if message["role"] == "tool" and isinstance(message["content"], list):
                # Tool results must be serialized to a string for DeepSeek.
                message["content"] = json.dumps(message["content"])
            elif message["role"] == "assistant":
                if isinstance(message["content"], list):
                    # DeepSeek API expects assistant content to be a string,
                    # not a list. Extract text blocks and join them, or use
                    # empty string if none exist.
                    text_parts = [
                        block.get("text", "")
                        for block in message["content"]
                        if isinstance(block, dict) and block.get("type") == "text"
                    ]
                    message["content"] = "".join(text_parts) if text_parts else ""
                # DeepSeek reasoning models require every assistant message to
                # carry a reasoning_content field (even when empty). The value
                # is stored in AIMessage.additional_kwargs by
                # _create_chat_result(); re-inject it into the API payload.
                if (
                    "reasoning_content" not in message
                    and i < len(messages)
                    and isinstance(messages[i], AIMessage)
                ):
                    message["reasoning_content"] = messages[i].additional_kwargs.get(
                        "reasoning_content", ""
                    )
        return payload

    ChatDeepSeek._get_request_payload = _patched_get_request_payload
    ChatDeepSeek._moviepilot_reasoning_content_patched = True
    logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
class LLMHelper:
"""LLM模型相关辅助功能"""
@staticmethod
def _should_disable_thinking(disable_thinking: bool | None = None) -> bool:
"""
判断本次调用是否应尝试关闭模型思考能力。
"""
if disable_thinking is not None:
return bool(disable_thinking)
return bool(getattr(settings, "LLM_DISABLE_THINKING", False))
_SUPPORTED_THINKING_LEVELS = frozenset(
{"off", "auto", "minimal", "low", "medium", "high", "max", "xhigh"}
)
@staticmethod
def _normalize_model_name(model_name: str | None) -> str:
@@ -94,48 +197,164 @@ class LLMHelper:
return (model_name or "").strip().lower()
@classmethod
def _build_disabled_thinking_kwargs(
cls,
provider: str,
model: str | None,
disable_thinking: bool | None = None,
def _normalize_deepseek_reasoning_effort(
cls, thinking_level: str | None = None
) -> str | None:
"""
DeepSeek 文档当前建议使用 high/max兼容常见 effort 别名。
"""
if not thinking_level or thinking_level in {"off", "auto"}:
return None
if thinking_level in {"minimal", "low", "medium", "high"}:
return "high"
if thinking_level in {"max", "xhigh"}:
return "max"
logger.warning(f"忽略不支持的 DeepSeek reasoning_effort 配置: {thinking_level}")
return None
@classmethod
def _normalize_openai_reasoning_effort(
cls, thinking_level: str | None = None
) -> str | None:
"""
OpenAI reasoning_effort 支持更细粒度的 effort统一做最近似映射。
"""
if not thinking_level or thinking_level == "auto":
return None
if thinking_level == "off":
return "none"
if thinking_level == "max":
return "xhigh"
return thinking_level
@classmethod
def _build_google_thinking_kwargs(
cls, model_name: str, thinking_level: str
) -> dict[str, Any]:
"""
按 provider/model 生成“禁用思考”相关参数
优先使用 LangChain/OpenAI SDK 已支持的原生字段;仅在 provider
明确要求自定义请求体时,才回退到 extra_body。
Gemini 3 使用 thinking_levelGemini 2.5 使用 thinking_budget
"""
if not cls._should_disable_thinking(disable_thinking):
if not model_name or thinking_level == "auto":
return {}
provider_name = (provider or "").strip().lower()
model_name = cls._normalize_model_name(model)
if not model_name:
return {}
# Moonshot Kimi K2.5/K2.6 需要在请求体显式声明 thinking.disabled。
if model_name.startswith(("kimi-k2.5", "kimi-k2.6")):
return {"extra_body": {"thinking": {"type": "disabled"}}}
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
if provider_name == "openai" and model_name.startswith(
("gpt-5", "o1", "o3", "o4")
):
return {"reasoning_effort": "none"}
# Gemini 使用 google-genai / langchain-google-genai 内置思考控制参数。
if provider_name == "google":
if "gemini-2.5" in model_name:
if "gemini-2.5" in model_name:
if thinking_level == "off":
if "pro" in model_name:
# Gemini 2.5 Pro 官方不支持完全关闭思考,回退到最小预算。
return {
"thinking_budget": 128,
"include_thoughts": False,
}
return {
"thinking_budget": 0,
"include_thoughts": False,
}
if "gemini-3" in model_name:
return {
"thinking_level": "minimal",
budget_map = {
"minimal": 512,
"low": 1024,
"medium": 4096,
"high": 8192,
"max": 24576,
"xhigh": 24576,
}
budget = budget_map.get(thinking_level)
return (
{
"thinking_budget": budget,
"include_thoughts": False,
}
if budget is not None
else {}
)
if "gemini-3" in model_name:
level_map = {
"off": "minimal",
"minimal": "minimal",
"low": "low",
"medium": "medium",
"high": "high",
"max": "high",
"xhigh": "high",
}
google_level = level_map.get(thinking_level)
return (
{
"thinking_level": google_level,
"include_thoughts": False,
}
if google_level
else {}
)
return {}
@classmethod
def _build_kimi_thinking_kwargs(
cls, model_name: str, thinking_level: str
) -> dict[str, Any]:
"""
Kimi 当前公开文档仅支持思考开关,不支持显式深度调节。
"""
if model_name.startswith("kimi-k2-thinking"):
return {}
if thinking_level == "off":
return {"extra_body": {"thinking": {"type": "disabled"}}}
return {}
@classmethod
def _build_thinking_kwargs(
cls,
provider: str,
model: str | None,
thinking_level: str | None = None
) -> dict[str, Any]:
"""
按 provider/model 生成思考模式相关参数。
优先使用 LangChain/OpenAI SDK 已支持的原生字段;仅在 provider
明确要求自定义请求体时,才回退到 extra_body。
"""
provider_name = (provider or "").strip().lower()
model_name = cls._normalize_model_name(model)
if provider_name == "deepseek":
if thinking_level == "off":
return {"extra_body": {"thinking": {"type": "disabled"}}}
if thinking_level == "auto":
return {}
kwargs: dict[str, Any] = {"extra_body": {"thinking": {"type": "enabled"}}}
deepseek_effort = cls._normalize_deepseek_reasoning_effort(
thinking_level
)
if deepseek_effort:
kwargs["reasoning_effort"] = deepseek_effort
return kwargs
if model_name.startswith(("kimi-k2.5", "kimi-k2.6", "kimi-k2-thinking")):
return cls._build_kimi_thinking_kwargs(model_name, thinking_level)
if not model_name:
return {}
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
if provider_name == "openai" and model_name.startswith(
("gpt-5", "o1", "o3", "o4")
):
openai_effort = cls._normalize_openai_reasoning_effort(
thinking_level
)
return {"reasoning_effort": openai_effort} if openai_effort else {}
# Gemini 使用 google-genai / langchain-google-genai 内置思考控制参数。
if provider_name == "google":
return cls._build_google_thinking_kwargs(
model_name, thinking_level
)
return {}
@@ -148,16 +367,26 @@ class LLMHelper:
@staticmethod
def get_llm(
streaming: bool = False,
provider: str | None = None,
model: str | None = None,
disable_thinking: bool | None = None,
api_key: str | None = None,
base_url: str | None = None,
streaming: bool = False,
provider: str | None = None,
model: str | None = None,
thinking_level: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
):
"""
获取LLM实例
:param streaming: 是否启用流式输出
:param provider: LLM提供商默认为配置项LLM_PROVIDER
:param model: 模型名称默认为配置项LLM_MODEL
:param thinking_level: 思考模式级别,默认为 None即自动判断
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal""low""medium""high""max"/"xhigh"(最大)。
不同模型对思考模式的支持和表现不同,具体映射关系请
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
:param api_key: API Key默认为
配置项LLM_API_KEY。对于某些提供商
如 DeepSeek可能需要同时提供 base_url。
:param base_url: API Base URL默认为配置项LLM_BASE_URL。
:return: LLM实例
"""
provider_name = str(
@@ -166,10 +395,10 @@ class LLMHelper:
model_name = model if model is not None else settings.LLM_MODEL
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
thinking_kwargs = LLMHelper._build_disabled_thinking_kwargs(
thinking_kwargs = LLMHelper._build_thinking_kwargs(
provider=provider_name,
model=model_name,
disable_thinking=disable_thinking,
thinking_level=thinking_level
)
if not api_key_value:
@@ -201,9 +430,11 @@ class LLMHelper:
elif provider_name == "deepseek":
from langchain_deepseek import ChatDeepSeek
_patch_deepseek_reasoning_content_support()
model = ChatDeepSeek(
model=model_name,
api_key=api_key_value,
api_base=base_url_value,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
@@ -231,7 +462,7 @@ class LLMHelper:
else:
model.profile = {
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
* 1000, # 转换为token单位
* 1000, # 转换为token单位
}
return model
@@ -255,10 +486,10 @@ class LLMHelper:
if isinstance(block, dict) or hasattr(block, "get"):
block_type = block.get("type")
if block.get("thought") or block_type in (
"thinking",
"reasoning_content",
"reasoning",
"thought",
"thinking",
"reasoning_content",
"reasoning",
"thought",
):
continue
if block_type == "text":
@@ -278,13 +509,13 @@ class LLMHelper:
@staticmethod
async def test_current_settings(
prompt: str = "请只回复 OK",
timeout: int = 20,
provider: str | None = None,
model: str | None = None,
disable_thinking: bool | None = None,
api_key: str | None = None,
base_url: str | None = None,
prompt: str = "请只回复 OK",
timeout: int = 20,
provider: str | None = None,
model: str | None = None,
thinking_level: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> dict:
"""
使用当前已保存配置执行一次最小 LLM 调用。
@@ -298,7 +529,7 @@ class LLMHelper:
streaming=False,
provider=provider_name,
model=model_name,
disable_thinking=disable_thinking,
thinking_level=thinking_level,
api_key=api_key_value,
base_url=base_url_value,
)
@@ -326,7 +557,7 @@ class LLMHelper:
return data
def get_models(
self, provider: str, api_key: str, base_url: str = None
self, provider: str, api_key: str, base_url: str = None
) -> List[str]:
"""获取模型列表"""
logger.info(f"获取 {provider} 模型列表...")
@@ -364,7 +595,7 @@ class LLMHelper:
@staticmethod
def _get_openai_compatible_models(
provider: str, api_key: str, base_url: str = None
provider: str, api_key: str, base_url: str = None
) -> List[str]:
"""获取OpenAI兼容模型列表"""
try:

1175
app/helper/skill.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -439,6 +439,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
chat_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> bool:
"""
编辑消息
@@ -448,6 +449,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
:param chat_id: 聊天ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 新的按钮列表
:return: 编辑是否成功
"""
if channel != self._channel:
@@ -460,6 +462,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
result = client.send_msg(
title=title or "",
text=text,
buttons=buttons,
original_message_id=message_id,
original_chat_id=str(chat_id),
)

View File

@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import json
import re
from urllib.parse import urljoin
from lxml import etree
@@ -11,6 +13,121 @@ from app.utils.string import StringUtils
class NexusAudiencesSiteUserInfo(NexusPhpSiteUserInfo):
schema = SiteSchema.NexusAudiences
    def _parse_user_traffic_info(self, html_text):
        """
        Parse user traffic info via the generic NexusPHP parser, then overlay
        values taken from the new Audiences user bar.
        """
        super()._parse_user_traffic_info(html_text)
        self.__parse_userbar_info(html_text)
    def _parse_user_detail_info(self, html_text: str):
        """
        Parse extra user details via the generic NexusPHP parser, then overlay
        values taken from the new Audiences user bar.
        """
        super()._parse_user_detail_info(html_text)
        self.__parse_userbar_info(html_text)
    def __parse_userbar_info(self, html_text: str):
        """
        Parse the new Audiences top user bar, overriding values the generic
        NexusPHP regexes may have mis-detected.
        """
        html = etree.HTML(html_text)
        try:
            if not StringUtils.is_valid_html_element(html):
                return
            # User card nodes carry identity and traffic data as attributes.
            for user_node in html.xpath('//*[@data-uploader-url or @data-uploader-stats]'):
                self.__parse_user_identity(user_node)
                self.__parse_uploader_stats(user_node.get("data-uploader-stats"))
            # data-uploader-stats does not include the share ratio; read it
            # (and the other metrics) from the compact-metric class names.
            self.__parse_compact_metric(html, "ratio", "ratio")
            self.__parse_compact_metric(html, "uploaded", "upload")
            self.__parse_compact_metric(html, "downloaded", "download")
            self.__parse_compact_metric(html, "bonus", "bonus")
            self.__parse_compact_metric(html, "active", "active")
        finally:
            # Release the parsed lxml tree promptly.
            if html is not None:
                del html
def __parse_user_identity(self, user_node):
"""
从新版用户卡属性中提取用户 ID、用户名和等级。
"""
user_url = user_node.get("data-uploader-url") or ""
user_detail = re.search(r"userdetails\.php\?id=(\d+)", user_url)
if user_detail and user_detail.group(1).strip():
self.userid = user_detail.group(1).strip()
username = user_node.get("data-uploader-label")
if username and username.strip():
self.username = username.strip()
user_level = user_node.get("data-uploader-badge")
if user_level and user_level.strip():
self.user_level = user_level.strip()
def __parse_uploader_stats(self, stats_text: str):
"""
解析 data-uploader-stats 中的结构化流量数据。
"""
if not stats_text:
return
try:
stats = json.loads(stats_text)
except (TypeError, ValueError):
return
if not isinstance(stats, list):
return
for item in stats:
if not isinstance(item, dict):
continue
label = str(item.get("label") or "").strip(" :")
tone = str(item.get("tone") or "").strip()
value = str(item.get("value") or "").strip()
self.__set_metric_value(label=label, tone=tone, value=value)
    def __parse_compact_metric(self, html, metric: str, field: str):
        """
        Read one value from the new user bar by its compact-metric class.

        :param html: parsed lxml document
        :param metric: suffix of the ``site-userbar__compact-metric--*`` class
        :param field: canonical field name forwarded to ``__set_metric_value``
        """
        values = html.xpath(
            f'//*[contains(concat(" ", normalize-space(@class), " "), " site-userbar__compact-metric--{metric} ")]'
            '//span[normalize-space()][last()]/text()'
        )
        if not values:
            # Fall back to the element's own text when no inner span matches.
            values = html.xpath(
                f'//*[contains(concat(" ", normalize-space(@class), " "), " site-userbar__compact-metric--{metric} ")]'
                '/text()'
            )
        if values:
            self.__set_metric_value(field=field, value=values[-1].strip())
    def __set_metric_value(self, value: str, label: str = None, tone: str = None, field: str = None):
        """
        Write one Audiences user-bar metric into the generic user-data fields.

        The metric is identified by ``field`` first, then ``tone``, then
        ``label`` (labels may be the Chinese captions shown in the site UI).
        """
        if not value:
            return
        metric_key = field or tone or label
        if metric_key in {"uploaded", "上传量", "upload"}:
            self.upload = StringUtils.num_filesize(value)
        elif metric_key in {"downloaded", "下载量", "download"}:
            self.download = StringUtils.num_filesize(value)
        elif metric_key in {"bonus", "爆米花"}:
            self.bonus = StringUtils.str_float(value)
        elif metric_key == "ratio":
            self.ratio = StringUtils.str_float(value)
        elif metric_key in {"active", "活跃"}:
            # Value looks like "<seeding> / ↓ <leeching>"; NOTE(review): the
            # regex assumes "↓" marks the second number — confirm against the
            # live site markup.
            active_match = re.search(r"\s*(\d+)\s*/\s*↓\s*(\d+)", value)
            if active_match:
                self.seeding = StringUtils.str_int(active_match.group(1))
                self.leeching = StringUtils.str_int(active_match.group(2))
def _parse_seeding_pages(self):
if not self._torrent_seeding_page:
return

View File

@@ -557,6 +557,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
chat_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> bool:
"""
编辑消息
@@ -566,6 +567,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
:param chat_id: 聊天ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 新的按钮列表
:return: 编辑是否成功
"""
if channel != self._channel:
@@ -578,6 +580,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
result = client.send_msg(
title=title or "",
text=text,
buttons=buttons,
original_message_id=str(message_id),
original_chat_id=str(chat_id),
)

View File

@@ -564,6 +564,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
chat_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> bool:
"""
编辑消息
@@ -573,6 +574,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
:param chat_id: 聊天ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 新的按钮列表
:return: 编辑是否成功
"""
if channel != self._channel:
@@ -587,6 +589,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
message_id=message_id,
text=text,
title=title,
buttons=buttons,
)
if result:
return True

View File

@@ -835,6 +835,7 @@ class Telegram:
message_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> Optional[bool]:
"""
编辑Telegram消息公开方法
@@ -842,6 +843,7 @@ class Telegram:
:param message_id: 消息ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 新的按钮列表
:return: 编辑是否成功
"""
if not self._bot:
@@ -861,6 +863,7 @@ class Telegram:
chat_id=str(chat_id),
message_id=int(message_id),
text=caption,
buttons=buttons,
)
except Exception as e:
logger.error(f"编辑Telegram消息异常: {str(e)}")

View File

@@ -11,6 +11,7 @@ from .monitoring import *
from .plugin import *
from .response import *
from .rule import *
from .openai import *
from .servarr import *
from .servcookie import *
from .site import *
@@ -23,4 +24,3 @@ from .transfer import *
from .user import *
from .workflow import *
from .mcp import *

156
app/schemas/openai.py Normal file
View File

@@ -0,0 +1,156 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
class OpenAIModelInfo(BaseModel):
    """One entry of the OpenAI-compatible ``/models`` listing."""
    id: str
    object: str = "model"
    created: int
    owned_by: str = "moviepilot"
class OpenAIModelListResponse(BaseModel):
    """Envelope returned by the OpenAI-compatible model-list endpoint."""
    object: str = "list"
    data: List[OpenAIModelInfo] = Field(default_factory=list)
class OpenAIChatMessage(BaseModel):
    """A single chat message; ``content`` may be a string or content-part list."""
    role: str
    content: Any
    name: Optional[str] = None
    # Tolerate extra OpenAI fields (tool_calls, etc.) without validation errors.
    model_config = ConfigDict(extra="allow")
class OpenAIChatCompletionsRequest(BaseModel):
    """Request body of the OpenAI-compatible ``/chat/completions`` endpoint."""
    model: Optional[str] = None
    messages: List[OpenAIChatMessage]
    user: Optional[str] = None
    stream: bool = False
    # Accept unknown OpenAI parameters (temperature, tools, ...) silently.
    model_config = ConfigDict(extra="allow")
class OpenAIResponsesRequest(BaseModel):
    """Request body of the OpenAI-compatible ``/responses`` endpoint."""
    model: Optional[str] = None
    # ``input`` may be a string, a list of typed items, or a message dict.
    input: Any
    instructions: Optional[str] = None
    user: Optional[str] = None
    stream: bool = False
    model_config = ConfigDict(extra="allow")
class OpenAIChatChoiceMessage(BaseModel):
    """Assistant message embedded in a completion choice."""
    role: str = "assistant"
    content: str
class OpenAIChatChoice(BaseModel):
    """A single choice of a chat-completion response."""
    index: int = 0
    message: OpenAIChatChoiceMessage
    finish_reason: str = "stop"
class OpenAIUsage(BaseModel):
    """Token accounting block; zeroed when the backend reports no usage."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
class OpenAIChatCompletionResponse(BaseModel):
    """Full response envelope of ``/chat/completions``."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[OpenAIChatChoice]
    usage: OpenAIUsage
class OpenAIResponsesOutputText(BaseModel):
    """Text content block inside a Responses-API output message."""
    type: str = "output_text"
    text: str
    annotations: List[Dict[str, Any]] = Field(default_factory=list)
class OpenAIResponsesOutputMessage(BaseModel):
    """Completed assistant message item in a Responses-API output list."""
    id: str
    type: str = "message"
    status: str = "completed"
    role: str = "assistant"
    content: List[OpenAIResponsesOutputText] = Field(default_factory=list)
class OpenAIResponsesResponse(BaseModel):
    """Full response envelope of the OpenAI-compatible ``/responses`` endpoint."""
    id: str
    object: str = "response"
    created_at: int
    status: str = "completed"
    model: str
    output: List[OpenAIResponsesOutputMessage] = Field(default_factory=list)
    error: Optional[Any] = None
    incomplete_details: Optional[Any] = None
    usage: OpenAIUsage
class OpenAIErrorDetail(BaseModel):
    """Error payload matching OpenAI's error object shape."""
    message: str
    type: str = "invalid_request_error"
    param: Optional[str] = None
    code: Optional[str] = None
class OpenAIErrorResponse(BaseModel):
    """Top-level error envelope returned on OpenAI-compatible failures."""
    error: OpenAIErrorDetail
# Loosely-typed content part of an OpenAI chat message (text / image blocks).
OpenAIChatContentPart = Dict[str, Any]
class AnthropicMessage(BaseModel):
    """A single Anthropic message; ``content`` may be a string or block list."""
    role: str
    content: Any
    model_config = ConfigDict(extra="allow")
class AnthropicMessagesRequest(BaseModel):
    """Request body of the Anthropic-compatible ``/v1/messages`` endpoint."""
    model: Optional[str] = None
    messages: List[AnthropicMessage]
    # ``system`` may be a string or a list of content blocks.
    system: Optional[Any] = None
    max_tokens: Optional[int] = 1024
    stream: bool = False
    model_config = ConfigDict(extra="allow")
class AnthropicTextBlock(BaseModel):
    """Text content block of an Anthropic response message."""
    type: str = "text"
    text: str
class AnthropicUsage(BaseModel):
    """Anthropic-style token accounting; zeroed when no usage is reported."""
    input_tokens: int = 0
    output_tokens: int = 0
class AnthropicMessagesResponse(BaseModel):
    """Full response envelope of the Anthropic-compatible messages endpoint."""
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[AnthropicTextBlock] = Field(default_factory=list)
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: AnthropicUsage = Field(default_factory=AnthropicUsage)
class AnthropicErrorDetail(BaseModel):
    """Error payload matching Anthropic's error object shape."""
    type: str = "invalid_request_error"
    message: str
class AnthropicErrorResponse(BaseModel):
    """Top-level error envelope returned on Anthropic-compatible failures."""
    type: str = "error"
    error: AnthropicErrorDetail

View File

@@ -76,14 +76,14 @@ pympler~=1.1
smbprotocol~=1.15.0
setproctitle~=1.3.6
httpx[socks]~=0.28.1
langchain~=1.2.13
langchain-core~=1.2.20
langchain~=1.2.15
langchain-core~=1.3.2
langchain-community~=0.4.1
langchain-openai~=1.1.11
langchain-google-genai~=4.2.1
langchain-openai~=1.2.1
langchain-google-genai~=4.2.2
langchain-deepseek~=1.0.1
langgraph~=1.1.3
openai~=2.29.0
google-genai~=1.68.0
langgraph~=1.1.9
openai~=2.32.0
google-genai~=1.73.1
ddgs~=9.10.0
websocket-client~=1.8.0

View File

@@ -1063,6 +1063,32 @@ def _prompt_choice(label: str, choices: dict[str, str], default: str) -> str:
print("请输入列表中的可选值。")
def _env_llm_thinking_level_default() -> str:
    """
    Resolve the LLM_THINKING_LEVEL environment default, folding legacy
    boolean-style aliases into the canonical level names. Unknown values
    fall back to "auto".
    """
    raw = _normalize_choice(_env_default("LLM_THINKING_LEVEL", ""))
    aliases = {
        "none": "off",
        "disabled": "off",
        "disable": "off",
        "enabled": "auto",
        "enable": "auto",
        "default": "auto",
        "dynamic": "auto",
    }
    level = aliases.get(raw, raw)
    valid_levels = {"off", "auto", "minimal", "low", "medium", "high", "max", "xhigh"}
    return level if level in valid_levels else "auto"
def _prompt_path(label: str, *, default: Path, allow_empty: bool = False) -> str:
value = _prompt_text(label, default=str(default), allow_empty=allow_empty)
if not value:
@@ -1476,9 +1502,19 @@ def _collect_agent_config() -> dict[str, Any]:
current_value=read_env_value("LLM_API_KEY"),
required=True,
),
"LLM_DISABLE_THINKING": _prompt_yes_no(
"是否尽量关闭模型思考/推理",
default=_env_bool("LLM_DISABLE_THINKING", False),
"LLM_THINKING_LEVEL": _prompt_choice(
"LLM 思考模式/深度",
choices={
"off": "关闭思考",
"auto": "自动",
"minimal": "最小",
"low": "",
"medium": "",
"high": "",
"max": "极高",
"xhigh": "超高",
},
default=_env_llm_thinking_level_default(),
),
"LLM_SUPPORT_IMAGE_INPUT": _prompt_yes_no(
"是否启用图片输入支持",
@@ -1506,7 +1542,7 @@ def _load_auth_site_definitions_inner() -> dict[str, Any]:
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
auth_sites = SitesHelper().get_authsites() or {}
definitions: dict[str, Any] = {}
@@ -1843,7 +1879,7 @@ def _apply_local_system_config_inner(config_payload: dict[str, Any]) -> None:
):
system_config.set(SystemConfigKey.UserSiteAuthParams, site_auth_item)
try:
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
status, msg = SitesHelper().check_user(
site_auth_item.get("site"), site_auth_item.get("params")

View File

@@ -0,0 +1,164 @@
---
name: create-moviepilot-skill
version: 1
description: >-
Use this skill when the user asks to create, scaffold, update, or review a
MoviePilot agent skill. This includes adding a new built-in skill under the
repository `skills/` directory, editing an existing built-in skill, writing
`SKILL.md` frontmatter and workflow instructions, choosing `allowed-tools`,
adding helper scripts when needed, and bumping the built-in skill `version`
so changes can sync into `config/agent/skills`.
allowed-tools: list_directory read_file write_file edit_file execute_command
---
# Create MoviePilot Skill
This skill guides you through creating or updating a built-in MoviePilot agent
skill in this repository.
## Scope
Use this workflow for repository built-in skills:
- Create or update files under `skills/<skill-id>/`
- Commit the skill as part of the MoviePilot repository
- Do not place the implementation only in `config/agent/skills` unless the user
explicitly asks for a local override instead of a built-in skill
## MoviePilot-Specific Rules
- The repository root `skills/` directory is the bundled source of truth for
built-in skills.
- On agent startup, bundled skills are synced into `config/agent/skills`.
- Sync overwrite depends on the `version` field in `SKILL.md`. If you update an
existing built-in skill, increment `version`, or users may continue using an
older copied version.
- Keep the folder name and frontmatter `name` identical. Use lowercase letters,
digits, and hyphens only.
- Prefer extending an existing skill instead of creating an overlapping
duplicate.
## Workflow
### Step 1: Understand the Request
- Determine whether the user wants a new skill or a change to an existing one.
- Extract the target task, likely trigger phrases, needed tools, and whether
helper scripts are necessary.
- If the goal is still ambiguous after reading the request and local context,
ask one focused clarification question. Otherwise proceed with a reasonable
default.
### Step 2: Check Existing Skills First
- Inspect the repository `skills/` directory before creating anything new.
- If an existing skill already covers most of the workflow, update it instead of
adding a near-duplicate.
- Reuse the repository style: concise YAML frontmatter, trigger-rich
description, and procedural body sections.
### Step 3: Choose the Skill ID and Path
- New built-in skill path: `skills/<skill-id>/SKILL.md`
- Keep `<skill-id>` short, hyphen-case, and under 64 characters.
- Use a verb-led or domain-led name that makes the trigger obvious, such as
`transfer-failed-retry`, `moviepilot-api`, or `create-moviepilot-skill`.
### Step 4: Write Frontmatter Correctly
Use this shape:
```markdown
---
name: create-moviepilot-skill
version: 1
description: >-
Explain what the skill does and exactly when to use it.
allowed-tools: list_directory read_file write_file edit_file execute_command
---
```
Rules:
- `description` is the primary trigger surface. Put concrete "when to use"
scenarios there.
- Include `version` for built-in skills. Increment it whenever you ship a new
built-in revision.
- Add `allowed-tools` when the workflow depends on a small, well-defined tool
set.
- Add `compatibility` only when environment constraints actually matter.
### Step 5: Write the Body
The body should contain:
- A short purpose statement
- MoviePilot-specific rules or guardrails
- A step-by-step workflow
- Concrete examples of matching user requests
- References to supporting files when they exist
Prefer:
- Imperative instructions
- Concrete file paths
- Examples aligned with actual MoviePilot conventions
Avoid:
- Generic theory that does not change execution
- Large duplicated documentation
- Extra files like `README.md` or `CHANGELOG.md` inside the skill directory
### Step 6: Add Supporting Files Only When They Help
- Add `scripts/` only when the same deterministic work would otherwise be
rewritten repeatedly.
- Keep helper files inside the same skill directory.
- Reference helper paths explicitly from `SKILL.md`.
- If the skill is instructions-only, keep it to a single `SKILL.md`.
### Step 7: Implement the Skill
For a new built-in skill:
1. Create `skills/<skill-id>/`
2. Create `SKILL.md`
3. Add helper scripts only if they are justified
For an existing built-in skill:
1. Edit `skills/<skill-id>/SKILL.md`
2. Increment `version`
3. Update helper files in the same directory if needed
### Step 8: Validate Before Finishing
- Re-read the frontmatter and confirm `name` matches the directory name.
- Confirm `description` mentions real trigger scenarios.
- If you changed an existing built-in skill, confirm `version` increased.
- If possible, validate the file can be parsed by the MoviePilot skills loader.
- Report the final path and note whether the agent needs a restart to sync the
latest built-in skill into `config/agent/skills`.
## Minimal Example
User request:
`给 MoviePilot agent 加一个处理站点 Cookie 更新的内置技能`
Expected outcome:
- Create or update a directory such as `skills/update-site-cookie/`
- Write `SKILL.md` with a trigger-rich `description`
- Include only the tools needed for that workflow
- Increment `version` when revising an existing built-in skill
## Final Checklist
- Is the skill under the repository `skills/` directory?
- Does the folder name equal frontmatter `name`?
- Does `description` clearly say when the skill should trigger?
- Did you avoid duplicating an existing skill unnecessarily?
- Did you increment `version` for built-in skill updates?
- Did you keep the skill lean and procedural?

View File

@@ -0,0 +1,29 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, patch
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
class TestAgentAddSubscribeTool(unittest.TestCase):
    """Behavioural tests for the agent's AddSubscribeTool."""

    def test_tv_subscription_without_season_reports_default_first_season(self):
        """A TV subscription without an explicit season defaults to season 1."""
        subscribe_tool = AddSubscribeTool(session_id="session-1", user_id="10001")
        fake_async_add = AsyncMock(return_value=(1, ""))
        with patch(
            "app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
            new=fake_async_add,
        ):
            outcome = asyncio.run(
                subscribe_tool.run(
                    title="Breaking Bad",
                    year="2008",
                    media_type="tv",
                )
            )
        for expected_fragment in ("第1季", "默认按第一季订阅"):
            self.assertIn(expected_fragment, outcome)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
AskUserChoiceTool,
UserChoiceOptionInput,
)
from app.agent.interaction import (
from app.chain.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)

View File

@@ -0,0 +1,62 @@
import unittest
from unittest.mock import patch
from app.agent.middleware.memory import MEMORY_ONBOARDING_PROMPT
from app.agent.prompt import prompt_manager
from app.core.config import settings
class TestAgentPromptStyle(unittest.TestCase):
    """Pin the wording rules injected into the agent system prompt."""

    # Exact non-verbose "stay silent between tool calls" rule text.
    _SILENCE_RULE = (
        "DO NOT output any conversational text, explanations, progress updates, "
        "or acknowledgements before the first tool call or between tool calls"
    )

    @staticmethod
    def _prompt() -> str:
        """Return the current agent system prompt."""
        return prompt_manager.get_agent_prompt()

    def test_agent_prompt_enforces_concise_professional_style(self):
        """The prompt must demand a restrained, non-flattering tone."""
        prompt = self._prompt()
        for fragment in (
            "professional, concise, restrained",
            "Do NOT flatter the user",
            "NO praise, emotional cushioning",
        ):
            self.assertIn(fragment, prompt)

    def test_agent_prompt_defines_tv_subscription_default_season_rule(self):
        """The prompt must spell out the season-1 default and per-season calls."""
        prompt = self._prompt()
        self.assertIn("omitting `season` means subscribe to season 1 only", prompt)
        self.assertIn("call `add_subscribe` separately for each season", prompt)

    def test_non_verbose_prompt_requires_silence_until_all_tools_finish(self):
        """Non-verbose mode injects the strictly-enforced silence instruction."""
        with patch.object(settings, "AI_AGENT_VERBOSE", False):
            prompt = self._prompt()
        self.assertIn("[Important Instruction] STRICTLY ENFORCED:", prompt)
        self.assertIn(self._SILENCE_RULE, prompt)
        self.assertIn("Only then may you send one final user-facing reply", prompt)

    def test_verbose_prompt_does_not_inject_silence_until_tools_finish_rule(self):
        """Verbose mode must not contain the silence instruction."""
        with patch.object(settings, "AI_AGENT_VERBOSE", True):
            prompt = self._prompt()
        self.assertNotIn(self._SILENCE_RULE, prompt)

    def test_memory_onboarding_does_not_force_warm_intro(self):
        """Onboarding must not mandate warm greetings or task interruption."""
        onboarding = MEMORY_ONBOARDING_PROMPT
        self.assertIn("Do NOT interrupt the current task", onboarding)
        self.assertIn("Do NOT proactively greet warmly", onboarding)
        self.assertNotIn("greet the user warmly", onboarding)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,106 @@
import asyncio
import unittest
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
from app.agent.middleware.usage import UsageMiddleware
from app.chain.message import MessageChain
from app.schemas.types import MessageChannel
class TestAgentSessionStatus(unittest.TestCase):
    """Tests for per-session LLM usage tracking and the /session_status reply."""

    def test_usage_middleware_records_usage_metadata(self):
        """UsageMiddleware should emit one snapshot per model call with token counts."""
        snapshots = []
        middleware = UsageMiddleware(on_usage=snapshots.append)
        request = ModelRequest(
            model=SimpleNamespace(
                model="gpt-4o-mini", profile={"max_input_tokens": 128000}
            ),
            messages=[],
            state={},
            runtime=None,
        )
        response = ModelResponse(
            result=[
                AIMessage(
                    content="ok",
                    usage_metadata={
                        "input_tokens": 1200,
                        "output_tokens": 300,
                        "total_tokens": 1500,
                    },
                )
            ]
        )

        async def handler(_: ModelRequest):
            # Stand-in for the wrapped model call; returns the canned response.
            return response

        result = asyncio.run(middleware.awrap_model_call(request, handler))
        # The middleware must pass the response through unchanged ...
        self.assertIs(result, response)
        # ... and record exactly one usage snapshot mirroring the metadata.
        self.assertEqual(len(snapshots), 1)
        self.assertEqual(snapshots[0]["model"], "gpt-4o-mini")
        self.assertEqual(snapshots[0]["context_window_tokens"], 128000)
        self.assertEqual(snapshots[0]["input_tokens"], 1200)
        self.assertEqual(snapshots[0]["output_tokens"], 300)
        self.assertEqual(snapshots[0]["total_tokens"], 1500)
        self.assertAlmostEqual(snapshots[0]["context_usage_ratio"], 1200 / 128000)

    def test_remote_session_status_sends_usage_summary(self):
        """An active session should produce a formatted usage summary message."""
        chain = MessageChain()
        # Simulate an existing agent session bound to user 10001.
        chain._user_sessions["10001"] = ("session-1", datetime.now())
        status = {
            "session_id": "session-1",
            "model": "gpt-4o-mini",
            "context_window_tokens": 128000,
            "last_input_tokens": 1200,
            "last_output_tokens": 300,
            "last_total_tokens": 1500,
            "last_context_usage_ratio": 1200 / 128000,
            "total_input_tokens": 4500,
            "total_output_tokens": 1500,
            "total_tokens": 6000,
            "model_call_count": 4,
            "last_updated_at": "2026-04-26 12:34:56",
            "is_processing": True,
            "pending_messages": 2,
        }
        with (
            patch(
                "app.chain.message.agent_manager.get_session_status",
                return_value=status,
            ),
            patch.object(chain, "post_message") as post_message,
        ):
            chain.remote_session_status(
                channel=MessageChannel.Telegram,
                userid="10001",
                source="telegram-test",
            )
        notification = post_message.call_args.args[0]
        self.assertEqual(notification.title, "当前智能体会话状态")
        self.assertIn("session-1", notification.text)
        self.assertIn("gpt-4o-mini", notification.text)
        # Token counts are rendered with thousands separators plus a usage ratio.
        self.assertIn("1,200 / 128,000 (0.94%)", notification.text)
        self.assertIn("输入 4,500 / 输出 1,500 / 总计 6,000", notification.text)
        self.assertIn("运行中", notification.text)

    def test_remote_session_status_handles_missing_session(self):
        """Without an active session the user gets an explanatory message."""
        chain = MessageChain()
        with patch.object(chain, "post_message") as post_message:
            chain.remote_session_status(
                channel=MessageChannel.Telegram,
                userid="10001",
                source="telegram-test",
            )
        notification = post_message.call_args.args[0]
        self.assertEqual(notification.title, "您当前没有活跃的智能体会话")

View File

@@ -0,0 +1,120 @@
import unittest
from unittest.mock import patch
from langchain.agents.middleware import SummarizationMiddleware
import app.agent as agent_module
class _FakeLLM:
_llm_type = "openai-chat"
def __init__(self, model: str):
self.model = model
self.profile = {"max_input_tokens": 64000}
class TestAgentSummarizationStreaming(unittest.TestCase):
    """Verify which LLM instance internal middlewares receive when streaming."""

    def test_streaming_agent_uses_non_streaming_llm_for_summary(self):
        """Streaming agents must hand a separate non-streaming LLM to summarization."""
        agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
        main_llm = _FakeLLM("main")
        non_streaming_llm = _FakeLLM("non-streaming")
        captured: dict = {}

        def _fake_create_agent(**kwargs):
            # Capture create_agent kwargs for later inspection.
            captured.update(kwargs)
            return object()

        with (
            patch.object(
                agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
            ),
            patch.object(agent, "_initialize_tools", return_value=[]),
            patch.object(
                agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
            ),
            patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
        ):
            agent._create_agent(streaming=True)
        summary_middleware = next(
            middleware
            for middleware in captured["middleware"]
            if isinstance(middleware, SummarizationMiddleware)
        )
        self.assertIs(captured["model"], main_llm)
        self.assertIs(summary_middleware.model, non_streaming_llm)

    def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self):
        """Model-based middlewares (tool selector) also get the non-streaming LLM."""
        agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
        main_llm = _FakeLLM("main")
        non_streaming_llm = _FakeLLM("non-streaming")
        captured: dict = {}

        class _FakeToolSelectorMiddleware:
            """Stand-in recording the model/max_tools it was constructed with."""

            def __init__(self, model, max_tools):
                self.model = model
                self.max_tools = max_tools

        def _fake_create_agent(**kwargs):
            captured.update(kwargs)
            return object()

        with (
            patch.object(
                agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
            ),
            patch.object(agent, "_initialize_tools", return_value=[]),
            patch.object(
                agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
            ),
            patch.object(
                agent_module,
                "LLMToolSelectorMiddleware",
                _FakeToolSelectorMiddleware,
            ),
            patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
            patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
        ):
            agent._create_agent(streaming=True)
        tool_selector_middleware = next(
            middleware
            for middleware in captured["middleware"]
            if isinstance(middleware, _FakeToolSelectorMiddleware)
        )
        self.assertIs(tool_selector_middleware.model, non_streaming_llm)

    def test_non_streaming_agent_reuses_main_llm_for_summary(self):
        """Without streaming a second LLM is unnecessary; the main one is reused."""
        agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
        main_llm = _FakeLLM("main")
        captured: dict = {}

        def _fake_create_agent(**kwargs):
            captured.update(kwargs)
            return object()

        with (
            patch.object(agent, "_initialize_llm", return_value=main_llm),
            patch.object(agent, "_initialize_tools", return_value=[]),
            patch.object(
                agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
            ),
            patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
        ):
            agent._create_agent(streaming=False)
        summary_middleware = next(
            middleware
            for middleware in captured["middleware"]
            if isinstance(middleware, SummarizationMiddleware)
        )
        self.assertIs(captured["model"], main_llm)
        self.assertIs(summary_middleware.model, main_llm)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,113 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, patch
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.base import MoviePilotTool
from app.core.config import settings
from app.schemas.message import MessageResponse
from app.schemas.types import MessageChannel
class DummyTool(MoviePilotTool):
    """Minimal MoviePilotTool used to exercise streaming-handler behaviour."""

    # Tool identity exposed to the agent runtime.
    name: str = "dummy_tool"
    description: str = "Dummy tool for streaming tests."

    async def run(self, **kwargs) -> str:
        """Trivial tool body; always succeeds with a constant payload."""
        return "ok"
class TestAgentToolStreaming(unittest.TestCase):
    """Tests for tool-call separators and message flushing in StreamingHandler."""

    async def _run_tool(self, initial_buffer: str) -> tuple[str, str]:
        """Run DummyTool in non-verbose mode; return (tool result, stream buffer)."""
        tool = DummyTool(session_id="session-1", user_id="10001")
        handler = StreamingHandler()
        await handler.start_streaming()
        if initial_buffer:
            handler.emit(initial_buffer)
        tool.set_stream_handler(handler)
        with patch.object(settings, "AI_AGENT_VERBOSE", False):
            result = await tool._arun(explanation="run test tool")
        buffered_message = await handler.take()
        return result, buffered_message

    def test_non_verbose_tool_call_appends_newline_separator(self):
        """A tool call after streamed text should append a newline separator."""
        result, buffered_message = asyncio.run(self._run_tool("prefix"))
        self.assertEqual(result, "ok")
        self.assertEqual(buffered_message, "prefix\n")

    def test_non_verbose_tool_call_does_not_duplicate_newline(self):
        """No second newline is added when the buffer already ends with one."""
        result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
        self.assertEqual(result, "ok")
        self.assertEqual(buffered_message, "prefix\n")

    def test_non_verbose_tool_call_keeps_empty_buffer_unchanged(self):
        """An empty buffer stays empty; no separator is injected."""
        result, buffered_message = asyncio.run(self._run_tool(""))
        self.assertEqual(result, "ok")
        self.assertEqual(buffered_message, "")

    def test_flush_sends_direct_message_via_threadpool(self):
        """First flush with no prior message sends a new direct message off-loop."""
        handler = StreamingHandler()
        # Wire the handler as if a Telegram session had been attached.
        handler._channel = MessageChannel.Telegram.value
        handler._source = "telegram"
        handler._user_id = "10001"
        handler._username = "tester"
        handler._streaming_enabled = True
        handler.emit("hello")
        with patch(
            "app.agent.callback.run_in_threadpool", new_callable=AsyncMock
        ) as run_in_threadpool_mock:
            run_in_threadpool_mock.return_value = MessageResponse(
                message_id=1,
                chat_id=2,
                source="telegram",
                success=True,
            )
            asyncio.run(handler._flush())
        self.assertEqual(run_in_threadpool_mock.await_count, 1)
        self.assertEqual(
            run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
        )
        self.assertTrue(handler.has_sent_message)

    def test_flush_edits_message_via_threadpool(self):
        """Subsequent flushes edit the previously sent message instead of resending."""
        handler = StreamingHandler()
        handler._channel = MessageChannel.Telegram.value
        handler._streaming_enabled = True
        # Pretend a message was already delivered earlier in the stream.
        handler._message_response = MessageResponse(
            message_id=1,
            chat_id=2,
            source="telegram",
            success=True,
        )
        handler._sent_text = "hello"
        handler.emit("hello world")
        with patch(
            "app.agent.callback.run_in_threadpool", new_callable=AsyncMock
        ) as run_in_threadpool_mock:
            run_in_threadpool_mock.return_value = True
            asyncio.run(handler._flush())
        self.assertEqual(run_in_threadpool_mock.await_count, 1)
        self.assertEqual(
            run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
        )
        self.assertEqual(handler._sent_text, "hello world")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,61 @@
import asyncio
import os
import shlex
import subprocess
import sys
import time
import unittest
from app.agent.tools.impl.execute_command import (
ExecuteCommandTool,
MAX_OUTPUT_CHARS,
)
def _python_command(code: str) -> str:
"""生成当前解释器可执行的 shell 命令,避免依赖系统 python 名称。"""
args = [sys.executable, "-c", code]
if os.name == "nt":
return subprocess.list2cmdline(args)
return " ".join(shlex.quote(arg) for arg in args)
class TestExecuteCommandTool(unittest.TestCase):
    """Integration tests for ExecuteCommandTool output limits and timeouts."""

    def _run_command(self, command: str, timeout: int = 60) -> str:
        """Run *command* through a fresh tool instance and return its report."""
        tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
        return asyncio.run(tool.run(command=command, timeout=timeout))

    def test_large_output_is_truncated_before_returning_to_agent(self):
        """Output beyond MAX_OUTPUT_CHARS must be truncated with a notice."""
        command = _python_command(
            "import sys; sys.stdout.write('x' * 200000); sys.stdout.flush()"
        )
        result = self._run_command(command)
        self.assertIn("输出内容过长,已截断", result)
        # Allow some headroom for the truncation notice itself.
        self.assertLess(len(result), MAX_OUTPUT_CHARS + 500)

    def test_timeout_returns_partial_output_promptly(self):
        """On timeout the tool returns quickly and preserves partial output."""
        command = _python_command(
            "import time; print('started', flush=True); time.sleep(5)"
        )
        started_at = time.monotonic()
        result = self._run_command(command, timeout=1)
        duration = time.monotonic() - started_at
        # Must come back well before the child's 5-second sleep would finish.
        self.assertLess(duration, 4)
        self.assertIn("命令执行超时", result)
        self.assertIn("started", result)

    def test_timeout_is_capped(self):
        """Excessive timeout values are clamped and reported to the agent."""
        command = _python_command("print('ok')")
        result = self._run_command(command, timeout=9999)
        self.assertIn("timeout 参数超过上限", result)
        self.assertIn("ok", result)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,144 @@
import importlib.util
import sys
import unittest
from pathlib import Path
from types import ModuleType
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
def _stub_module(name: str, **attrs):
module = sys.modules.get(name)
if module is None:
module = ModuleType(name)
sys.modules[name] = module
for key, value in attrs.items():
setattr(module, key, value)
return module
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
def _build_tool_call(name: str = "search", arguments: str = "{}"):
return [
{
"id": "call_1",
"type": "tool_call",
"name": name,
"args": {},
}
]
class _FakeChatDeepSeek:
def __init__(self, model_name: str, model_kwargs: dict | None = None):
self.model_name = model_name
self.model_kwargs = model_kwargs or {}
def _get_request_payload(self, input_, *, stop=None, **kwargs):
messages = []
for message in input_:
payload_message = {
"role": message.type,
"content": message.content,
}
if message.type == "human":
payload_message["role"] = "user"
elif message.type == "ai":
payload_message["role"] = "assistant"
tool_calls = getattr(message, "tool_calls", None)
if tool_calls:
payload_message["tool_calls"] = tool_calls
elif message.type == "tool":
payload_message["role"] = "tool"
payload_message["tool_call_id"] = message.tool_call_id
messages.append(payload_message)
return {"messages": messages}
# Keep an unpatched reference so each test can restore the pristine method.
_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload

# Drop any previously imported copy so llm.py re-imports against the stubs below.
sys.modules.pop("app.helper.llm", None)
_stub_module(
    "app.core.config",
    settings=ModuleType("settings"),
)
# Minimal settings surface read by app/helper/llm.py at import time.
sys.modules["app.core.config"].settings.LLM_PROVIDER = "deepseek"
sys.modules["app.core.config"].settings.LLM_MODEL = "deepseek-v4-pro"
sys.modules["app.core.config"].settings.LLM_API_KEY = "sk-test"
sys.modules["app.core.config"].settings.LLM_BASE_URL = "https://api.deepseek.com"
sys.modules["app.core.config"].settings.LLM_THINKING_LEVEL = None
sys.modules["app.core.config"].settings.LLM_TEMPERATURE = 0.1
sys.modules["app.core.config"].settings.LLM_MAX_CONTEXT_TOKENS = 64
sys.modules["app.core.config"].settings.PROXY_HOST = None
_stub_module("app.log", logger=_DummyLogger())
# Swap the real langchain_deepseek package for the fake defined above.
_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek)

# Load app/helper/llm.py straight from disk under a private module name so the
# stubbed dependencies take effect regardless of prior imports.
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path)
llm_module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(llm_module)
class DeepSeekCompatPatchTest(unittest.TestCase):
    """Tests for the DeepSeek reasoning_content compatibility monkey-patch."""

    def setUp(self):
        """Reset the fake class to its unpatched state, then re-apply the patch."""
        _FakeChatDeepSeek._get_request_payload = _ORIGINAL_GET_REQUEST_PAYLOAD
        if hasattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched"):
            delattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched")
        llm_module._patch_deepseek_reasoning_content_support()

    def test_injects_reasoning_content_for_assistant_tool_calls(self):
        """Assistant tool-call messages should carry their reasoning_content."""
        llm = _FakeChatDeepSeek("deepseek-v4-pro")
        messages = [
            HumanMessage(content="天气如何?"),
            AIMessage(
                content="",
                tool_calls=_build_tool_call(),
                additional_kwargs={"reasoning_content": "先调用天气工具"},
            ),
            ToolMessage(content="晴天", tool_call_id="call_1"),
        ]
        payload = llm._get_request_payload(messages)
        self.assertEqual(
            payload["messages"][1]["reasoning_content"],
            "先调用天气工具",
        )

    def test_falls_back_to_empty_reasoning_content_when_missing(self):
        """Missing reasoning still yields an (empty) reasoning_content field."""
        llm = _FakeChatDeepSeek("deepseek-v4-flash")
        messages = [
            HumanMessage(content="天气如何?"),
            AIMessage(content="", tool_calls=_build_tool_call()),
            ToolMessage(content="晴天", tool_call_id="call_1"),
        ]
        payload = llm._get_request_payload(messages)
        self.assertIn("reasoning_content", payload["messages"][1])
        self.assertEqual(payload["messages"][1]["reasoning_content"], "")

    def test_skips_injection_when_thinking_is_disabled(self):
        """With thinking explicitly disabled, no reasoning_content is injected."""
        llm = _FakeChatDeepSeek(
            "deepseek-v4-pro",
            model_kwargs={"extra_body": {"thinking": {"type": "disabled"}}},
        )
        messages = [
            HumanMessage(content="天气如何?"),
            AIMessage(
                content="",
                tool_calls=_build_tool_call(),
                additional_kwargs={"reasoning_content": "先调用天气工具"},
            ),
            ToolMessage(content="晴天", tool_call_id="call_1"),
        ]
        payload = llm._get_request_payload(messages)
        self.assertNotIn("reasoning_content", payload["messages"][1])

View File

@@ -38,7 +38,7 @@ _stub_module(
LLM_MODEL="global-model",
LLM_API_KEY="global-key",
LLM_BASE_URL="https://global.example.com",
LLM_DISABLE_THINKING=False,
LLM_THINKING_LEVEL=None,
LLM_TEMPERATURE=0.1,
LLM_MAX_CONTEXT_TOKENS=64,
PROXY_HOST=None,
@@ -83,7 +83,9 @@ class LlmHelperTestCallTest(unittest.TestCase):
streaming=False,
provider="deepseek",
model="deepseek-chat",
thinking_level=None,
disable_thinking=None,
reasoning_effort=None,
api_key="sk-test",
base_url="https://api.deepseek.com",
)
@@ -138,7 +140,77 @@ class LlmHelperTestCallTest(unittest.TestCase):
{"thinking": {"type": "disabled"}},
)
def test_get_llm_uses_openai_reasoning_effort_none(self):
def test_get_llm_uses_deepseek_thinking_level_controls(self):
calls = []
patch_calls = []
class _FakeChatDeepSeek:
def __init__(self, **kwargs):
calls.append(kwargs)
self.model = kwargs["model"]
self.profile = None
with patch.dict(
sys.modules,
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
), patch.object(
llm_module,
"_patch_deepseek_reasoning_content_support",
side_effect=lambda: patch_calls.append(True),
):
llm_module.LLMHelper.get_llm(
provider="deepseek",
model="deepseek-v4-pro",
thinking_level="xhigh",
api_key="sk-test",
base_url="https://api.deepseek.com",
)
self.assertEqual(len(calls), 1)
self.assertEqual(
calls[0].get("extra_body"),
{"thinking": {"type": "enabled"}},
)
self.assertEqual(patch_calls, [True])
self.assertEqual(calls[0].get("reasoning_effort"), "max")
self.assertEqual(calls[0].get("api_base"), "https://api.deepseek.com")
def test_get_llm_disables_deepseek_thinking_via_thinking_level(self):
calls = []
patch_calls = []
class _FakeChatDeepSeek:
def __init__(self, **kwargs):
calls.append(kwargs)
self.model = kwargs["model"]
self.profile = None
with patch.dict(
sys.modules,
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
), patch.object(
llm_module,
"_patch_deepseek_reasoning_content_support",
side_effect=lambda: patch_calls.append(True),
):
llm_module.LLMHelper.get_llm(
provider="deepseek",
model="deepseek-v4-flash",
thinking_level="off",
api_key="sk-test",
base_url="https://proxy.example.com",
)
self.assertEqual(len(calls), 1)
self.assertEqual(
calls[0].get("extra_body"),
{"thinking": {"type": "disabled"}},
)
self.assertEqual(patch_calls, [True])
self.assertIsNone(calls[0].get("reasoning_effort"))
self.assertEqual(calls[0].get("api_base"), "https://proxy.example.com")
def test_get_llm_uses_openai_reasoning_effort_none_for_off(self):
calls = []
class _FakeChatOpenAI:
@@ -154,7 +226,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
llm_module.LLMHelper.get_llm(
provider="openai",
model="gpt-5-mini",
disable_thinking=True,
thinking_level="off",
api_key="sk-test",
base_url="https://api.openai.com/v1",
)
@@ -162,6 +234,30 @@ class LlmHelperTestCallTest(unittest.TestCase):
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].get("reasoning_effort"), "none")
def test_get_llm_maps_unified_max_to_openai_xhigh(self):
calls = []
class _FakeChatOpenAI:
def __init__(self, **kwargs):
calls.append(kwargs)
self.model = kwargs["model"]
self.profile = None
with patch.dict(
sys.modules,
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
):
llm_module.LLMHelper.get_llm(
provider="openai",
model="gpt-5.4",
thinking_level="max",
api_key="sk-test",
base_url="https://api.openai.com/v1",
)
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].get("reasoning_effort"), "xhigh")
def test_get_llm_uses_gemini_builtin_thinking_controls(self):
calls = []
@@ -182,7 +278,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
llm_module.LLMHelper.get_llm(
provider="google",
model="gemini-2.5-flash",
disable_thinking=True,
thinking_level="off",
api_key="sk-test",
base_url=None,
)
@@ -191,6 +287,35 @@ class LlmHelperTestCallTest(unittest.TestCase):
self.assertEqual(calls[0].get("thinking_budget"), 0)
self.assertFalse(calls[0].get("include_thoughts"))
def test_get_llm_uses_gemini_3_thinking_level_controls(self):
calls = []
class _FakeChatGoogleGenerativeAI:
def __init__(self, **kwargs):
calls.append(kwargs)
self.model = kwargs["model"]
self.profile = None
with patch.dict(
sys.modules,
{
"langchain_google_genai": SimpleNamespace(
ChatGoogleGenerativeAI=_FakeChatGoogleGenerativeAI
)
},
):
llm_module.LLMHelper.get_llm(
provider="google",
model="gemini-3.1-flash",
thinking_level="xhigh",
api_key="sk-test",
base_url=None,
)
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].get("thinking_level"), "high")
self.assertFalse(calls[0].get("include_thoughts"))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,158 @@
import sys
import unittest
from types import ModuleType
from unittest.mock import patch
# Stub heavyweight optional dependencies before the app imports below so this
# test module loads without qbittorrentapi/transmission_rpc/psutil installed.
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
setattr(sys.modules["transmission_rpc"], "File", object)
sys.modules.setdefault("psutil", ModuleType("psutil"))
from app.chain.interaction import MediaInteractionChain, media_interaction_manager
from app.chain.message import MessageChain
from app.core.context import MediaInfo
from app.core.meta import MetaBase
from app.schemas.types import MessageChannel
class TestMediaInteraction(unittest.TestCase):
    """Tests for media-search interactions routed ahead of the AI agent."""

    def tearDown(self):
        # Each test creates interaction requests; wipe shared state between runs.
        media_interaction_manager.clear()

    @staticmethod
    def _build_meta(name: str) -> MetaBase:
        """Build a minimal MetaBase pinned to season 1 for the given title."""
        meta = MetaBase(name)
        meta.name = name
        meta.begin_season = 1
        return meta

    def test_message_routes_text_reply_to_media_interaction_before_ai(self):
        """Pending media interactions must consume text replies before the AI does."""
        chain = MessageChain()
        request = media_interaction_manager.create_or_replace(
            user_id="10001",
            channel=MessageChannel.Wechat,
            source="wechat-test",
            username="tester",
            action="Search",
            keyword="星际穿越",
            title="星际穿越",
            meta=self._build_meta("星际穿越"),
            items=[MediaInfo(title="星际穿越", year="2014")],
        )
        self.assertIsNotNone(request)
        with patch.object(chain, "_record_user_message"), patch(
            "app.chain.message.MediaInteractionChain.handle_text_interaction",
            return_value=True,
        ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
            chain.handle_message(
                channel=MessageChannel.Wechat,
                source="wechat-test",
                userid="10001",
                username="tester",
                text="1",
            )
        # The interaction chain handled the reply; the AI path must stay idle.
        handle_text.assert_called_once()
        handle_ai.assert_not_called()

    def test_callback_routes_to_media_interaction_chain(self):
        """CALLBACK:media:<id>:... payloads are dispatched to the interaction chain."""
        chain = MessageChain()
        request = media_interaction_manager.create_or_replace(
            user_id="10001",
            channel=MessageChannel.Telegram,
            source="telegram-test",
            username="tester",
            action="Search",
            keyword="星际穿越",
            title="星际穿越",
            meta=self._build_meta("星际穿越"),
            items=[MediaInfo(title="星际穿越", year="2014")],
        )
        with patch(
            "app.chain.message.MediaInteractionChain.handle_callback_interaction",
            return_value=True,
        ) as handle_callback:
            chain._handle_callback(
                text=f"CALLBACK:media:{request.request_id}:page-next",
                channel=MessageChannel.Telegram,
                source="telegram-test",
                userid="10001",
                username="tester",
            )
        handle_callback.assert_called_once()

    def test_media_interaction_starts_search_and_posts_media_list(self):
        """Free text starts a search and posts a button list with media callbacks."""
        chain = MediaInteractionChain()
        meta = self._build_meta("星际穿越")
        medias = [
            MediaInfo(title="星际穿越", year="2014"),
            MediaInfo(title="Interstellar", year="2014"),
        ]
        with patch(
            "app.chain.interaction.MediaChain.search",
            return_value=(meta, medias),
        ), patch.object(chain, "post_medias_message") as post_medias_message:
            handled = chain.handle_text_interaction(
                channel=MessageChannel.Telegram,
                source="telegram-test",
                userid="10001",
                username="tester",
                text="星际穿越",
            )
        self.assertTrue(handled)
        post_medias_message.assert_called_once()
        notification = post_medias_message.call_args.args[0]
        self.assertTrue(notification.buttons)
        # Each button must carry a media-interaction callback payload.
        self.assertTrue(
            notification.buttons[0][0]["callback_data"].startswith("media:")
        )
        request = media_interaction_manager.get_by_user("10001")
        self.assertIsNotNone(request)
        self.assertEqual(request.action, "Search")
        self.assertEqual(len(request.items), 2)

    def test_media_interaction_legacy_page_callback_updates_existing_request(self):
        """Legacy page_n callbacks page an existing request and edit in place."""
        chain = MediaInteractionChain()
        request = media_interaction_manager.create_or_replace(
            user_id="10001",
            channel=MessageChannel.Telegram,
            source="telegram-test",
            username="tester",
            action="Search",
            keyword="星际穿越",
            title="星际穿越",
            meta=self._build_meta("星际穿越"),
            items=[
                MediaInfo(title=f"资源 {index}", year="2024")
                for index in range(1, 11)
            ],
        )
        with patch.object(chain, "post_medias_message") as post_medias_message:
            handled = chain.handle_callback_interaction(
                callback_data="page_n",
                channel=MessageChannel.Telegram,
                source="telegram-test",
                userid="10001",
                username="tester",
                original_message_id=123,
                original_chat_id="456",
            )
        self.assertTrue(handled)
        # Paging advances the stored request to the next page.
        self.assertEqual(request.page, 1)
        post_medias_message.assert_called_once()
        notification = post_medias_message.call_args.args[0]
        # Reply targets the original message so it is edited, not re-sent.
        self.assertEqual(notification.original_message_id, 123)
        self.assertEqual(notification.original_chat_id, "456")
if __name__ == "__main__":
unittest.main()

Some files were not shown because too many files have changed in this diff Show More