Refactor movie pilot config and test coverage

This commit is contained in:
jxxghp
2026-06-23 10:05:45 +08:00
parent dc773337d3
commit 0cd049bfc2
40 changed files with 1481 additions and 828 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import hashlib
import json
import re
import traceback
@@ -148,6 +149,16 @@ class _SessionUsageSnapshot:
}
@dataclass
class _CompiledAgentBundle:
"""会话内可复用的 Agent 图及其构造签名。"""
signature: tuple[Any, ...]
agent: Any
streaming: bool
created_at: datetime
class _ThinkTagStripper:
"""
流式剥离 <think>...</think> 标签的辅助类。
@@ -280,6 +291,8 @@ class MoviePilotAgent:
self._llm_runtime_config: Optional[Dict[str, Any]] = None
self._llm_provider_selection: Dict[str, Any] = {}
self._agent_started_at: Optional[datetime] = None
self._compiled_agent_bundle: Optional[_CompiledAgentBundle] = None
self._last_agent_cache_hit = False
# 流式token管理
self.stream_handler = StreamingHandler()
@@ -980,6 +993,90 @@ class MoviePilotAgent:
allow_message_tools=self.allow_message_tools,
)
def _refresh_tool_context(self, values: Dict[str, object]) -> None:
"""
刷新本轮工具共享上下文。
工具对象可能随会话内 Agent 图缓存被复用,因此这里保留 dict 对象本身,
只替换其中内容,确保缓存工具看到的是最新权限与回复状态。
"""
self._tool_context.clear()
self._tool_context.update(values)
@staticmethod
def _public_runtime_config_signature(runtime_config: Dict[str, Any]) -> tuple:
"""生成不包含密钥明文的 LLM 运行时签名。"""
api_key = runtime_config.get("api_key") or ""
api_key_digest = (
hashlib.sha256(str(api_key).encode("utf-8")).hexdigest()[:12]
if api_key
else ""
)
return (
runtime_config.get("provider"),
runtime_config.get("model"),
api_key_digest,
runtime_config.get("base_url"),
runtime_config.get("base_url_preset"),
runtime_config.get("user_agent"),
bool(runtime_config.get("use_proxy")),
runtime_config.get("thinking_level"),
)
async def _agent_bundle_signature(self, streaming: bool) -> tuple[Any, ...]:
"""构造会话内 Agent 图缓存签名。"""
runtime_config = await self._resolve_llm_runtime_config()
return (
streaming,
self.channel,
self.source,
self.user_id,
self.username,
self.allow_message_tools,
bool(self._tool_context.get("is_admin")),
self.has_message_context,
self.is_background,
settings.AI_AGENT_VERBOSE,
settings.LLM_MAX_TOOLS,
settings.LLM_MAX_ITERATIONS,
self._public_runtime_config_signature(runtime_config),
agent_runtime_manager.current_signature(),
)
def _get_cached_agent(
self, signature: tuple[Any, ...], streaming: bool
) -> Optional[Any]:
"""按签名读取当前会话已编译的 Agent 图。"""
bundle = self._compiled_agent_bundle
if (
bundle
and bundle.streaming == streaming
and bundle.signature == signature
):
return bundle.agent
return None
def _cache_agent(
self,
*,
signature: tuple[Any, ...],
agent: Any,
streaming: bool,
) -> Any:
"""保存当前会话可复用的 Agent 图。"""
self._compiled_agent_bundle = _CompiledAgentBundle(
signature=signature,
agent=agent,
streaming=streaming,
created_at=datetime.now(),
)
return agent
@staticmethod
def _latest_turn_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
"""从完整历史中提取本轮新增用户消息。"""
return [messages[-1]] if messages else []
def _initialize_subagent_tools(self) -> List:
"""
初始化子代理专用静默工具列表。
@@ -1006,6 +1103,13 @@ class MoviePilotAgent:
:param streaming: 是否启用流式输出
"""
try:
bundle_signature = await self._agent_bundle_signature(streaming)
cached_agent = self._get_cached_agent(bundle_signature, streaming)
self._last_agent_cache_hit = bool(cached_agent)
if cached_agent:
logger.debug(f"复用会话内 Agent 图: session_id={self.session_id}")
return cached_agent
# 系统提示词
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
@@ -1113,13 +1217,18 @@ class MoviePilotAgent:
)
)
return create_agent(
agent = create_agent(
model=agent_model,
tools=[*tools, *skill_tools, *activity_log_tools],
system_prompt=system_prompt,
middleware=middlewares,
checkpointer=InMemorySaver(),
)
return self._cache_agent(
signature=bundle_signature,
agent=agent,
streaming=streaming,
)
except Exception as e:
logger.error(f"创建 Agent 失败: {e}")
raise e
@@ -1137,12 +1246,15 @@ class MoviePilotAgent:
user_display_saved = False
try:
logger.info(
f"Agent推理: session_id={self.session_id}, input={message}, "
f"Agent推理: session_id={self.session_id}, "
f"input_chars={len(message or '')}, "
f"images={len(images) if images else 0}, files={len(files) if files else 0}, "
f"audio_input={has_audio_input}"
)
self._tool_context = await self._build_tool_context(
should_dispatch_reply=self.should_dispatch_reply
self._refresh_tool_context(
await self._build_tool_context(
should_dispatch_reply=self.should_dispatch_reply
)
)
self._streamed_output = ""
@@ -1330,6 +1442,11 @@ class MoviePilotAgent:
# 创建智能体(根据是否流式传入不同 LLM
agent = await self._create_agent(streaming=use_streaming)
input_messages = (
self._latest_turn_messages(messages)
if self._last_agent_cache_hit
else messages
)
if use_streaming:
self.stream_handler.set_dispatch_policy(
@@ -1348,7 +1465,7 @@ class MoviePilotAgent:
# 流式运行智能体token 直接推送到 stream_handler
await self._stream_agent_tokens(
agent=agent,
messages={"messages": messages},
messages={"messages": input_messages},
config=agent_config,
on_token=self._handle_stream_text,
)
@@ -1387,7 +1504,7 @@ class MoviePilotAgent:
else:
# 非流式模式:后台任务或渠道不支持消息编辑
await agent.ainvoke(
{"messages": messages},
{"messages": input_messages},
config=agent_config,
)
@@ -1446,9 +1563,11 @@ class MoviePilotAgent:
except asyncio.CancelledError:
logger.info(f"Agent执行被取消: session_id={self.session_id}")
self._compiled_agent_bundle = None
execution_error = "任务已取消"
return "任务已取消", {}
except Exception as e:
self._compiled_agent_bundle = None
execution_error = str(e)
if self._messages_have_image_input(messages) and self._is_unsupported_image_input_error(e):
logger.warning(
@@ -1492,6 +1611,7 @@ class MoviePilotAgent:
"""
清理智能体资源
"""
self._compiled_agent_bundle = None
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")

View File

@@ -1628,6 +1628,37 @@ class LLMProviderManager(metaclass=Singleton):
)
return None
def _resolve_cached_model_record(
self,
provider_id: str,
model_id: Optional[str],
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
transport: str = "openai",
) -> dict[str, Any] | None:
"""从缓存中的模型元数据构造轻量模型记录,不触发远端模型列表刷新。"""
if not model_id:
return None
metadata = self.resolve_cached_model_metadata(
provider_id,
model_id,
base_url=base_url,
base_url_preset_id=base_url_preset_id,
) or {}
if not metadata:
return self._normalize_model_record(
model_id=model_id,
transport=transport,
source="configured",
)
return self._normalize_model_record(
model_id=model_id,
display_name=metadata.get("name") or model_id,
metadata=metadata,
transport=transport,
source="models.dev-cache",
)
@staticmethod
def _normalize_model_record(
model_id: str,
@@ -2104,7 +2135,7 @@ class LLMProviderManager(metaclass=Singleton):
try:
return jwt.decode(token, options={"verify_signature": False})
except Exception as err:
print(err)
logger.debug(f"解析 JWT token 内容失败: {err}")
return {}
@staticmethod
@@ -2587,40 +2618,29 @@ class LLMProviderManager(metaclass=Singleton):
)
normalized_api_key = str(api_key or "").strip() or None
normalized_base_url = self._sanitize_base_url(base_url)
model_record = None
if model:
try:
model_record = next(
(
item
for item in await self.list_models(
normalized_provider_id,
api_key=api_key,
base_url=base_url,
base_url_preset_id=normalized_base_url_preset_id,
user_agent=user_agent,
use_proxy=use_proxy,
)
if item["id"] == model
),
None,
)
except Exception as err:
print(err)
model_record = None
default_transport = (
"anthropic" if resolved_runtime == "anthropic_compatible" else "openai"
)
model_record = self._resolve_cached_model_record(
normalized_provider_id,
model,
base_url=base_url,
base_url_preset_id=normalized_base_url_preset_id,
transport=default_transport,
)
model_metadata = self.resolve_cached_model_metadata(
normalized_provider_id,
model,
base_url=base_url,
base_url_preset_id=normalized_base_url_preset_id,
)
result: dict[str, Any] = {
"provider_id": normalized_provider_id,
"runtime": resolved_runtime,
"model_id": model,
"model_record": model_record,
"model_metadata": await self.resolve_model_metadata(
normalized_provider_id,
model,
base_url=base_url,
base_url_preset_id=normalized_base_url_preset_id,
use_proxy=use_proxy,
),
"model_metadata": model_metadata,
"default_headers": None,
"use_responses_api": None,
"auth_mode": "api_key",
@@ -2631,8 +2651,7 @@ class LLMProviderManager(metaclass=Singleton):
try:
auth = await self._resolve_chatgpt_oauth()
except Exception as err:
print(err)
pass
logger.debug(f"解析 ChatGPT OAuth 鉴权失败,回退 API Key 模式: {err}")
if auth:
headers = {"originator": "moviepilot"}

View File

@@ -7,12 +7,14 @@
"""
import json
import os
import re
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
from pathlib import Path
from typing import Annotated, Any, NotRequired, Optional, TypedDict
import anyio
from anyio import Path as AsyncPath
from langchain.agents.middleware.types import (
AgentMiddleware,
@@ -579,14 +581,29 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
entry = f"- **{now_str}** {summary}\n"
try:
if await log_path.exists():
existing = await log_path.read_text(encoding="utf-8", errors="replace")
await log_path.write_text(existing + entry, encoding="utf-8")
async with await anyio.open_file(
log_path,
mode="a",
encoding="utf-8",
) as stream:
await stream.write(entry)
else:
header = f"# {today_str} 活动日志\n\n"
await log_path.write_text(header + entry, encoding="utf-8")
logger.debug("Activity logged: %s", summary[:80])
try:
fd = os.open(log_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError:
async with await anyio.open_file(
log_path,
mode="a",
encoding="utf-8",
) as stream:
await stream.write(entry)
else:
with os.fdopen(fd, "w", encoding="utf-8") as stream:
stream.write(header + entry)
logger.debug(f"Activity logged: {summary[:80]}")
except Exception as e:
logger.warning("Failed to append activity log: %s", e)
logger.warning(f"Failed to append activity log: {e}")
async def _cleanup_old_logs(self) -> None:
"""清理超过保留天数的旧日志文件。"""
@@ -608,20 +625,16 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
file_date = datetime.strptime(match.group(1), "%Y-%m-%d").date()
if file_date < cutoff_date:
await path.unlink()
logger.debug("Cleaned up old activity log: %s", path.name)
logger.debug(f"Cleaned up old activity log: {path.name}")
except ValueError:
continue
except Exception as e:
logger.warning("Failed to cleanup old activity logs: %s", e)
logger.warning(f"Failed to cleanup old activity logs: {e}")
async def abefore_agent(
self, state: ActivityLogState, runtime: Runtime
) -> Optional[ActivityLogStateUpdate]:
"""在 Agent 执行前加载近期活动日志。"""
# 如果已经加载则跳过
if "activity_log_contents" in state:
return None
contents = await self._load_recent_logs()
# 趁机清理旧日志(低频操作,不影响性能)
@@ -709,7 +722,7 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
if summary:
await self._append_activity(summary)
except Exception as e:
logger.warning("Failed to record activity: %s", e)
logger.warning(f"Failed to record activity: {e}")
return None

View File

@@ -283,12 +283,7 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
) -> JobsStateUpdate | None:
"""在 Agent 执行前异步加载任务元数据。
每个会话仅加载一次。若 state 中已有则跳过。
"""
# 如果 state 中已存在元数据则跳过
if "jobs_metadata" in state:
return None
return JobsStateUpdate(
jobs_metadata=await load_jobs_metadata(self.sources)
)

View File

@@ -302,7 +302,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
"""在代理执行前扫描记忆目录并加载所有 .md 文件的内容。
自动发现目录下所有 `.md` 文件并加载其内容到状态中。
如果状态中尚未存在则进行加载。
同时检测记忆文件是否为空,设置 memory_empty 标志位,
以便在系统提示词中触发初始化引导流程。
@@ -314,10 +313,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
返回:
填充了 memory_contents 和 memory_empty 的状态更新。
"""
# 如果已经加载则跳过
if "memory_contents" in state:
return None
# 扫描目录下所有 .md 文件
md_files = await self._scan_memory_files()

View File

@@ -322,7 +322,7 @@ def _extract_version(skill_md: Path) -> int:
try:
content = skill_md.read_text(encoding="utf-8", errors="replace")
except Exception as err:
print(err)
logger.debug(f"读取技能版本失败: {err}")
return 0
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
if not match:
@@ -627,13 +627,8 @@ class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # no
) -> SkillsStateUpdate | None: # ty: ignore[invalid-method-override]
"""在 Agent 执行前异步加载技能元数据。
每个会话仅加载一次。若 state 中已有则跳过。
首次加载时,会先将内置技能同步到用户目录(如不存在)。
"""
# 如果 state 中已存在元数据则跳过
if "skills_metadata" in state:
return None
self._sync_bundled_skills()
all_skills: dict[str, SkillMetadata] = {}

View File

@@ -197,18 +197,44 @@ def is_subagent_stream_metadata(metadata: Any) -> bool:
) == SUBAGENT_STREAM_MARKER_VALUE:
return True
return bool(metadata.get("lc_agent_name") in builtin_subagent_names())
return bool(
metadata.get("lc_agent_name")
in builtin_subagent_names(agent_runtime_manager.current_signature())
)
@lru_cache(maxsize=1)
def builtin_subagent_names() -> frozenset[str]:
def builtin_subagent_names(
runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None,
) -> frozenset[str]:
"""返回内置子代理名称集合。"""
return frozenset(profile.name for profile in _builtin_subagent_profiles())
runtime_signature = runtime_signature or agent_runtime_manager.current_signature()
return _cached_builtin_subagent_names(runtime_signature)
@lru_cache(maxsize=1)
def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]:
@lru_cache(maxsize=8)
def _cached_builtin_subagent_names(
runtime_signature: tuple[tuple[str, int, int], ...],
) -> frozenset[str]:
"""按运行时签名缓存内置子代理名称集合。"""
return frozenset(
profile.name
for profile in _builtin_subagent_profiles(runtime_signature)
)
def _builtin_subagent_profiles(
runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None,
) -> tuple[_SubAgentProfile, ...]:
"""从运行时配置目录加载 MoviePilot 子代理定义。"""
runtime_signature = runtime_signature or agent_runtime_manager.current_signature()
return _cached_builtin_subagent_profiles(runtime_signature)
@lru_cache(maxsize=8)
def _cached_builtin_subagent_profiles(
runtime_signature: tuple[tuple[str, int, int], ...],
) -> tuple[_SubAgentProfile, ...]:
"""按运行时签名缓存 MoviePilot 子代理定义。"""
definitions = agent_runtime_manager.list_subagents()
profiles = tuple(
_profile_from_runtime_definition(definition)
@@ -237,6 +263,10 @@ def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]:
)
builtin_subagent_names.cache_clear = _cached_builtin_subagent_names.cache_clear
_builtin_subagent_profiles.cache_clear = _cached_builtin_subagent_profiles.cache_clear
def _profile_from_runtime_definition(
definition: SubAgentDefinition,
) -> _SubAgentProfile:
@@ -1044,6 +1074,7 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
if unfinished_records:
logger.info(f"Agent 结束,取消未完成子代理任务: tasks={len(unfinished_records)}")
await self._cancel_records(unfinished_records)
self._tasks.clear()
async def awrap_tool_call(
self,
@@ -1083,9 +1114,8 @@ def create_subagent_middlewares(
stream_handler: Any = None,
) -> tuple[list[AgentMiddleware], list[BaseTool]]:
"""创建子代理中间件列表和任务工具列表。"""
_builtin_subagent_profiles.cache_clear()
builtin_subagent_names.cache_clear()
profiles = _builtin_subagent_profiles()
runtime_signature = agent_runtime_manager.current_signature()
profiles = _builtin_subagent_profiles(runtime_signature)
subagent_middleware = MoviePilotSubAgentMiddleware(
model=model,
profiles=profiles,

View File

@@ -592,22 +592,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果,
不会为每次模型回合重复追加一笔 selector LLM 开销。
"""
if "selected_tool_names" in state:
self._log_selection_attempt(
_ToolSelectionAttempt(
request=ModelRequest(
model=self.model,
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
),
selected_tool_names=state.get("selected_tool_names") or [],
status="reused",
)
)
return None
if not self.selection_tools or self.model is None:
detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型"
self._log_selection_attempt(

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import re
import shutil
import threading
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterable, Optional
@@ -243,9 +244,15 @@ class AgentRuntimeManager:
self._cache_lock = threading.Lock()
self._cached_signature: Optional[tuple[tuple[str, int, int], ...]] = None
self._cached_config: Optional[AgentRuntimeConfig] = None
self._cached_signature_checked_at = 0.0
self._signature_check_interval = 1.0
self._layout_ready = False
def ensure_layout(self) -> None:
"""创建目录、同步默认文件,并清理废弃的旧版 runtime 文件。"""
with self._cache_lock:
if self._layout_ready:
return
self.agent_root_dir.mkdir(parents=True, exist_ok=True)
self.runtime_dir.mkdir(parents=True, exist_ok=True)
self.memory_dir.mkdir(parents=True, exist_ok=True)
@@ -257,11 +264,13 @@ class AgentRuntimeManager:
self._remove_obsolete_runtime_files()
self._sync_bundled_defaults()
self._migrate_root_memory_files()
with self._cache_lock:
self._layout_ready = True
def load_runtime_config(self) -> AgentRuntimeConfig:
"""加载配置。用户目录损坏时自动回退到内置默认配置。"""
self.ensure_layout()
signature = self._build_signature()
signature = self.current_signature()
with self._cache_lock:
if self._cached_signature == signature and self._cached_config:
return self._cached_config
@@ -269,7 +278,7 @@ class AgentRuntimeManager:
try:
config = self._load_from_root(self.runtime_dir)
except AgentRuntimeConfigError as err:
logger.warning("Agent 根层配置无效,回退到内置默认配置: %s", err)
logger.warning(f"Agent 根层配置无效,回退到内置默认配置: {err}")
config = self._load_from_root(self.bundled_defaults_dir)
config.used_fallback = True
config.warnings.insert(
@@ -285,6 +294,25 @@ class AgentRuntimeManager:
with self._cache_lock:
self._cached_signature = None
self._cached_config = None
self._cached_signature_checked_at = 0.0
self._layout_ready = False
def current_signature(self) -> tuple[tuple[str, int, int], ...]:
"""返回当前运行时配置文件签名,供调用方判断缓存是否仍可复用。"""
now = time.monotonic()
with self._cache_lock:
if (
self._cached_signature is not None
and now - self._cached_signature_checked_at
< self._signature_check_interval
):
return self._cached_signature
signature = self._build_signature()
with self._cache_lock:
self._cached_signature = signature
self._cached_signature_checked_at = now
return signature
def set_active_persona(self, persona_query: str) -> AgentRuntimeConfig:
"""切换当前激活人格,并立即刷新缓存。"""
@@ -308,7 +336,7 @@ class AgentRuntimeManager:
)
current_path.write_text(document, encoding="utf-8")
self.invalidate_cache()
logger.info("已切换 Agent 人格: %s", persona.persona_id)
logger.info(f"已切换 Agent 人格: {persona.persona_id}")
return self.load_runtime_config()
def list_personas(self) -> list[PersonaDefinition]:
@@ -439,7 +467,7 @@ class AgentRuntimeManager:
continue
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, target)
logger.info("已同步默认 Agent 运行时文件: %s", target)
logger.info(f"已同步默认 Agent 运行时文件: {target}")
@classmethod
def _should_update_bundled_subagent(
@@ -478,7 +506,7 @@ class AgentRuntimeManager:
return
target.parent.mkdir(parents=True, exist_ok=True)
source.rename(target)
logger.info("已迁移旧版 Agent 根配置文件: %s -> %s", source, target)
logger.info(f"已迁移旧版 Agent 根配置文件: {source} -> {target}")
def _remove_obsolete_runtime_files(self) -> None:
"""删除不再支持的旧版 Agent 配置文件,避免被误迁移到 memory。"""
@@ -487,14 +515,14 @@ class AgentRuntimeManager:
if not path.exists() or not path.is_file():
continue
path.unlink()
logger.info("已删除废弃的 Agent 根配置文件: %s", path)
logger.info(f"已删除废弃的 Agent 根配置文件: {path}")
for relative_path in sorted(OBSOLETE_RUNTIME_FILES):
path = self.runtime_dir / relative_path
if not path.exists() or not path.is_file():
continue
path.unlink()
logger.info("已删除废弃的 Agent 运行时文件: %s", path)
logger.info(f"已删除废弃的 Agent 运行时文件: {path}")
def _migrate_root_memory_files(self) -> None:
"""将旧版根目录 memory 文件移入 `config/agent/memory`。"""
@@ -505,7 +533,7 @@ class AgentRuntimeManager:
if target.exists():
continue
path.rename(target)
logger.info("已迁移旧版 Agent memory 文件: %s -> %s", path, target)
logger.info(f"已迁移旧版 Agent memory 文件: {path} -> {target}")
def _load_from_root(self, root: Path) -> AgentRuntimeConfig:
current_persona_path = root / CURRENT_PERSONA_FILE

View File

@@ -431,7 +431,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
:return: 普通用户允许读写的本地目录列表
"""
roots = [
settings.CONFIG_PATH / "agent"
settings.CONFIG_PATH / "agent",
settings.LOG_PATH,
]
resolved_roots = []
for root in roots:
@@ -467,7 +468,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
allowed_text = "".join(str(root) for root in allowed_roots)
return (
resolved_path,
f"抱歉,普通用户只能{operation}配置目录、Agent记忆目录和日志目录内的文件或目录:{allowed_text}",
f"抱歉,普通用户只能{operation}Agent配置目录和日志目录内的文件或目录:{allowed_text}",
)
async def _check_local_storage_access(

View File

@@ -0,0 +1,88 @@
"""Agent 命令工具的安全校验逻辑。"""
from __future__ import annotations
import os.path
import re
import shlex
COMMAND_FORBIDDEN_KEYWORDS = (
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
COMMAND_DANGEROUS_PATTERNS = (
re.compile(r"\brm\s+[^;&|]*-[^\s;&|]*[rR][fF]?[^\s;&|]*\s+/(?:\s|$|[;&|])"),
re.compile(r"\bdd\s+[^;&|]*(?:of=/dev/(?:sd[a-z]\d*|nvme\d+n\d+p?\d*|disk\d+)|if=/dev/zero)"),
re.compile(r"\b(?:mkfs|fdisk|parted|diskutil)\b"),
re.compile(r"\b(?:chmod|chown)\s+[^;&|]*-R[^;&|]*\s+/(?:\s|$|[;&|])"),
re.compile(r"\b(?:reboot|shutdown|halt|poweroff)\b"),
)
def _command_tokens(command: str) -> list[str]:
"""尽力解析 shell 命令 token解析失败时退回空白分割。"""
try:
return shlex.split(command, posix=True)
except ValueError:
return re.split(r"\s+", command.strip())
def _contains_recursive_root_delete(command: str) -> bool:
"""识别递归删除根目录或一级目录的 rm 命令。"""
tokens = _command_tokens(command)
if not any(token == "rm" or token.endswith("/rm") for token in tokens):
return False
has_recursive = any(
token.startswith("-") and ("r" in token or "R" in token)
for token in tokens
)
if not has_recursive:
return False
for token in tokens:
clean_token = re.match(r"^([^;|&><]+)", token)
if not clean_token:
continue
path_value = clean_token.group(1).strip("\"'")
if not path_value.startswith("/"):
continue
norm_path = os.path.normpath(path_value)
if norm_path == "/" or re.match(r"^/[^/]+$", norm_path):
return True
return False
def detect_dangerous_command(command: str) -> str:
"""返回危险命令原因,安全时返回空字符串。"""
normalized = str(command or "").strip()
if not normalized:
return "命令不能为空"
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
if keyword in normalized:
return f"命令包含禁止使用的关键字 '{keyword}'"
if _contains_recursive_root_delete(normalized):
return "命令疑似递归删除根目录或一级目录"
for pattern in COMMAND_DANGEROUS_PATTERNS:
if pattern.search(normalized):
return "命令匹配高危系统操作模式"
return ""
def validate_command_safety(command: str, *, confirmed: bool = False) -> None:
"""
校验 shell 命令安全性。
:param command: 待执行命令
:param confirmed: 是否已经通过显式参数确认高危操作
"""
reason = detect_dangerous_command(command)
if not reason:
return
if confirmed and reason != "命令不能为空":
return
raise ValueError(f"{reason}。如确认需要执行,请设置 confirm_dangerous=true")

View File

@@ -5,8 +5,7 @@ import re
from typing import Any, Dict, Iterable, Optional
from app.core.event import eventmanager
from app.db import AsyncSessionFactory
from app.db.models.subscribe import Subscribe
from app.db.subscribe_oper import SubscribeOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.rule import RuleHelper
from app.modules.filter.RuleParser import RuleParser
@@ -284,23 +283,22 @@ async def collect_rule_group_usages(
continue
ensure_usage(name)["used_in_global_best_version"] = True
async with AsyncSessionFactory() as db:
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
filter_groups = subscribe.filter_groups or []
for name in filter_groups:
if target_names and name not in target_names:
continue
ensure_usage(name)["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"type": subscribe.type,
"username": subscribe.username,
"best_version": bool(subscribe.best_version),
}
)
subscribes = await SubscribeOper().async_list()
for subscribe in subscribes:
filter_groups = subscribe.filter_groups or []
for name in filter_groups:
if target_names and name not in target_names:
continue
ensure_usage(name)["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"type": subscribe.type,
"username": subscribe.username,
"best_version": bool(subscribe.best_version),
}
)
return usage_map
@@ -482,22 +480,22 @@ async def rename_rule_group_references(old_name: str, new_name: str) -> dict:
await save_system_config(config_key, updated)
changed["global_settings"][config_key.value] = updated
async with AsyncSessionFactory() as db:
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
original = subscribe.filter_groups or []
updated = replace_group_name_in_list(original, old_name, new_name)
if updated == original:
continue
await subscribe.async_update(db, {"filter_groups": updated})
changed["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"filter_groups": updated,
}
)
subscribe_oper = SubscribeOper()
subscribes = await subscribe_oper.async_list()
for subscribe in subscribes:
original = subscribe.filter_groups or []
updated = replace_group_name_in_list(original, old_name, new_name)
if updated == original:
continue
await subscribe_oper.async_update_filter_groups(subscribe.id, updated)
changed["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"filter_groups": updated,
}
)
return changed
@@ -520,21 +518,21 @@ async def remove_rule_group_references(group_name: str) -> dict:
await save_system_config(config_key, updated)
changed["global_settings"][config_key.value] = updated
async with AsyncSessionFactory() as db:
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
original = subscribe.filter_groups or []
updated = [value for value in original if value != group_name]
if updated == original:
continue
await subscribe.async_update(db, {"filter_groups": updated})
changed["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"filter_groups": updated,
}
)
subscribe_oper = SubscribeOper()
subscribes = await subscribe_oper.async_list()
for subscribe in subscribes:
original = subscribe.filter_groups or []
updated = [value for value in original if value != group_name]
if updated == original:
continue
await subscribe_oper.async_update_filter_groups(subscribe.id, updated)
changed["subscribes"].append(
{
"subscribe_id": subscribe.id,
"name": subscribe.name,
"season": subscribe.season,
"filter_groups": updated,
}
)
return changed

View File

@@ -1,7 +1,7 @@
"""系统设置工具共用的键解析与分组元数据。"""
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
from app.core.config import Settings
from app.schemas.types import SystemConfigKey
@@ -15,6 +15,7 @@ class SettingSpec:
source: str
group: str
label: str
systemconfig_key: Optional[SystemConfigKey] = None
SYSTEMCONFIG_SETTING_METADATA = {
@@ -234,6 +235,7 @@ def _build_specs() -> tuple[dict[str, SettingSpec], dict[str, SettingSpec]]:
source="systemconfig",
group=metadata.get("group", "misc"),
label=metadata.get("label", item.value),
systemconfig_key=item,
)
return core_specs, system_specs
@@ -333,3 +335,57 @@ def list_setting_specs(
def get_default_list_match_field(setting_key: str) -> Optional[str]:
return LIST_ITEM_MATCH_FIELD_DEFAULTS.get(setting_key)
SECRET_KEYWORDS = (
"api_key",
"apikey",
"token",
"secret",
"password",
"passwd",
"cookie",
"authorization",
"refresh_token",
"access_token",
)
def is_secret_setting_key(key: str) -> bool:
"""判断设置键名是否疑似敏感字段。"""
normalized = _normalize_token(key)
return any(keyword in normalized for keyword in SECRET_KEYWORDS)
def redact_secret_value(value: Any, *, redact_scalar: bool = False) -> Any:
"""递归脱敏配置值中的密钥、Cookie、Token 等敏感字段。"""
if isinstance(value, dict):
return {
key: "***"
if is_secret_setting_key(str(key))
else redact_secret_value(item, redact_scalar=redact_scalar)
for key, item in value.items()
}
if isinstance(value, list):
return [
redact_secret_value(item, redact_scalar=redact_scalar)
for item in value
]
if isinstance(value, str):
return "***" if value and redact_scalar else value
return value
def should_redact_setting(spec: SettingSpec, value: Any) -> bool:
"""判断某项设置在默认查询响应中是否需要脱敏。"""
if is_secret_setting_key(spec.key):
return True
if isinstance(value, dict):
return any(is_secret_setting_key(str(key)) for key in value.keys())
if isinstance(value, list):
return any(
should_redact_setting(spec, item)
for item in value
if isinstance(item, dict)
)
return False

View File

@@ -13,6 +13,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
from app.agent.tools.impl._command_safety import validate_command_safety
from app.core.config import settings
from app.log import logger
@@ -34,14 +35,6 @@ TERMINAL_PTY_POLL_INTERVAL = 0.05
TERMINAL_WAIT_DEFAULT_MS = 1000
TERMINAL_WAIT_MAX_MS = 60 * 1000
TERMINAL_KILL_GRACE_SECONDS = 3
TERMINAL_FORBIDDEN_KEYWORDS = (
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
@dataclass
@@ -176,13 +169,9 @@ class _TerminalSessionManager:
return merged_env
@staticmethod
def _validate_command(command: str) -> None:
def _validate_command(command: str, *, confirmed: bool = False) -> None:
"""拒绝明显危险或空白命令。"""
if not command or not command.strip():
raise ValueError("命令不能为空")
for keyword in TERMINAL_FORBIDDEN_KEYWORDS:
if keyword in command:
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
validate_command_safety(command, confirmed=confirmed)
@staticmethod
def _set_nonblocking(fd: int) -> None:
@@ -213,9 +202,10 @@ class _TerminalSessionManager:
cwd: Optional[str] = None,
env: Optional[dict[str, Any]] = None,
use_pty: Any = True,
confirm_dangerous: bool = False,
) -> dict[str, Any]:
"""启动后台命令并立即返回会话 ID。"""
self._validate_command(command)
self._validate_command(command, confirmed=confirm_dangerous)
normalized_cwd = self._normalize_cwd(cwd)
normalized_env = self._build_env(env)
should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix"
@@ -313,7 +303,10 @@ class _TerminalSessionManager:
continue
except OSError as err:
if err.errno not in {errno.EIO, errno.EBADF}:
logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err)
logger.debug(
f"PTY 输出读取异常: session_id={session.session_id}, "
f"error={err}"
)
break
if not data:
@@ -343,7 +336,9 @@ class _TerminalSessionManager:
session.mark_finished(session.exit_code)
except Exception as err:
session.mark_error(str(err))
logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err)
logger.warning(
f"等待 PTY 进程失败: session_id={session.session_id}, error={err}"
)
finally:
await self._finish_reader_tasks(session)
session.close_pty()
@@ -358,7 +353,9 @@ class _TerminalSessionManager:
session.mark_finished(exit_code)
except Exception as err:
session.mark_error(str(err))
logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err)
logger.warning(
f"等待管道进程失败: session_id={session.session_id}, error={err}"
)
finally:
await self._finish_reader_tasks(session)

View File

@@ -228,6 +228,11 @@ class BrowseWebpageTool(MoviePilotTool):
return "错误: 'fill_ref' 操作需要提供 value 参数"
if browser_action == BrowserAction.EVALUATE and not script:
return "错误: 'evaluate' 操作需要提供 script 参数"
if (
browser_action == BrowserAction.EVALUATE
and not await self.is_admin_user()
):
return "错误: 'evaluate' 操作仅允许管理员使用"
if (
browser_action in (BrowserAction.FOCUS_TAB, BrowserAction.CLOSE_TAB)
and tab_index is None

View File

@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.db import AsyncSessionFactory
from app.db.models.downloadhistory import DownloadHistory
from app.db.downloadhistory_oper import DownloadHistoryOper
from app.log import logger
@@ -40,9 +39,8 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
try:
async with AsyncSessionFactory() as db:
await DownloadHistory.async_delete(db, history_id)
return f"下载历史记录 ID: {history_id} 已成功删除"
await DownloadHistoryOper().async_delete_history(history_id)
return f"下载历史记录 ID: {history_id} 已成功删除"
except Exception as e:
logger.error(f"删除下载历史记录失败: {e}", exc_info=True)
return f"删除下载历史记录时发生错误: {str(e)}"

View File

@@ -12,7 +12,7 @@ from app.log import logger
class EditFileInput(BaseModel):
"""Input parameters for edit file tool"""
"""文件编辑工具的输入参数模型。"""
file_path: str = Field(..., description="The absolute path of the file to edit")
old_text: str = Field(..., description="The exact old text to be replaced")
@@ -27,8 +27,8 @@ class EditFileTool(MoviePilotTool):
]
description: str = (
"Edit a local text file by replacing specific old text with new text. "
"Non-admin users can only edit files inside the MoviePilot config, "
"Agent memory/activity, and log directories."
"Non-admin users can only edit files inside the MoviePilot Agent config "
"and log directories."
)
args_schema: Type[BaseModel] = EditFileInput

View File

@@ -13,6 +13,7 @@ from typing import Any, Literal, Optional, TextIO, Type
from pydantic import BaseModel, Field
from app.agent.tools.impl._command_safety import validate_command_safety
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.agent.tools.impl._terminal_session import (
@@ -30,14 +31,6 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
COMMAND_FORBIDDEN_KEYWORDS = (
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
@@ -195,6 +188,13 @@ class ExecuteCommandInput(BaseModel):
60,
description="For action=run, max execution time in seconds.",
)
confirm_dangerous: Optional[bool] = Field(
False,
description=(
"Explicit confirmation for high-risk commands such as recursive root deletion, "
"disk formatting, shutdown/reboot, or destructive permission changes."
),
)
class ExecuteCommandTool(MoviePilotTool):
@@ -255,34 +255,9 @@ class ExecuteCommandTool(MoviePilotTool):
return command
@staticmethod
def _validate_command(command: str) -> None:
def _validate_command(command: str, *, confirmed: bool = False) -> None:
"""复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。"""
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
if keyword in command:
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
# 检查是否使用了 rm -r/R 删除根目录或一级目录,防止误杀多级目录
import re
import os.path
tokens = re.split(r'\s+', command.strip())
if any(t == "rm" or t.endswith("/rm") for t in tokens):
has_r = False
for token in tokens:
if token.startswith("-") and ("r" in token or "R" in token):
has_r = True
break
if has_r:
for token in tokens:
# 提取可能包含目标路径的部分(去除重定向、管道、分号等末尾干扰)
m = re.match(r'^([^;\|&><]+)', token)
if m:
clean_token = m.group(1).strip('"\'')
# 仅对绝对路径进行一级目录限制
if clean_token.startswith('/'):
norm_path = os.path.normpath(clean_token)
if re.match(r'^/[^/]*$', norm_path) or re.match(r'^/[^/]*/$', norm_path):
raise ValueError(f"不允许使用 rm 命令删除根目录或一级目录: {clean_token}")
validate_command_safety(command, confirmed=confirmed)
@staticmethod
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
@@ -367,7 +342,7 @@ class ExecuteCommandTool(MoviePilotTool):
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
)
except asyncio.TimeoutError:
logger.warning("命令进程强制清理超时: pid=%s", process.pid)
logger.warning(f"命令进程强制清理超时: pid={process.pid}")
@staticmethod
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
@@ -382,7 +357,7 @@ class ExecuteCommandTool(MoviePilotTool):
if isinstance(result, Exception) and not isinstance(
result, asyncio.CancelledError
):
logger.debug("命令输出读取任务异常: %s", result)
logger.debug(f"命令输出读取任务异常: {result}")
@staticmethod
def _format_run_result(
@@ -425,9 +400,10 @@ class ExecuteCommandTool(MoviePilotTool):
command: str,
timeout: Optional[int],
cwd: Optional[str] = None,
confirm_dangerous: bool = False,
) -> str:
"""按旧模式一次性执行命令,等待完成或超时后返回文本结果。"""
self._validate_command(command)
self._validate_command(command, confirmed=confirm_dangerous)
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
async with _command_semaphore:
@@ -482,27 +458,29 @@ class ExecuteCommandTool(MoviePilotTool):
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
timeout: Optional[int] = 60,
confirm_dangerous: Optional[bool] = False,
**kwargs,
) -> str:
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
normalized_action = (action or "start").strip().lower()
logger.info(
"执行工具: %s, action=%s, command=%s, session_id=%s",
self.name,
normalized_action,
command,
session_id,
f"执行工具: {self.name}, action={normalized_action}, "
f"command={command}, session_id={session_id}"
)
try:
if normalized_action == "start":
start_command = self._require_command(command)
self._validate_command(start_command)
self._validate_command(
start_command,
confirmed=bool(confirm_dangerous),
)
payload = await terminal_session_manager.start(
command=start_command,
cwd=cwd,
env=env,
use_pty=use_pty,
confirm_dangerous=bool(confirm_dangerous),
)
return self._dump(payload)
@@ -542,9 +520,10 @@ class ExecuteCommandTool(MoviePilotTool):
command=self._require_command(command),
timeout=timeout,
cwd=cwd,
confirm_dangerous=bool(confirm_dangerous),
)
raise ValueError(f"不支持的 action: {action}")
except Exception as err:
logger.error("执行命令 action 失败: %s", err, exc_info=True)
logger.error(f"执行命令 action 失败: {err}", exc_info=True)
return self._dump({"error": str(err), "status": "error", "action": normalized_action})

View File

@@ -7,9 +7,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.db import AsyncSessionFactory
from app.db.models.site import Site
from app.db.models.siteuserdata import SiteUserData
from app.db.site_oper import SiteOper
from app.log import logger
SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10
@@ -66,118 +64,115 @@ class QuerySiteUserdataTool(MoviePilotTool):
)
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取站点
site = await Site.async_get(db, site_id)
if not site:
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
# 获取站点用户数据
user_data_list = await SiteUserData.async_get_by_domain(
db, domain=site.domain, workdate=workdate
site_oper = SiteOper()
site = await site_oper.async_get(site_id)
if not site:
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
if not user_data_list:
return json.dumps(
{
"success": False,
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate,
},
ensure_ascii=False,
)
user_data_list = await site_oper.async_get_userdata_by_domain(
domain=site.domain, workdate=workdate
)
# 格式化用户数据
result = {
"success": True,
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate,
"data_count": len(user_data_list),
"user_data": [],
if not user_data_list:
return json.dumps(
{
"success": False,
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate,
},
ensure_ascii=False,
)
# 格式化用户数据
result = {
"success": True,
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate,
"data_count": len(user_data_list),
"user_data": [],
}
for user_data in user_data_list:
# 格式化上传/下载量(转换为可读格式)
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
download_gb = (
user_data.download / (1024**3) if user_data.download else 0
)
seeding_size_gb = (
user_data.seeding_size / (1024**3)
if user_data.seeding_size
else 0
)
leeching_size_gb = (
user_data.leeching_size / (1024**3)
if user_data.leeching_size
else 0
)
seeding_preview, seeding_count, seeding_truncated = _preview_list(
user_data.seeding_info
)
unread_preview, unread_count, unread_truncated = _preview_list(
user_data.message_unread_contents
)
user_data_dict = {
"domain": user_data.domain,
"name": user_data.name,
"username": user_data.username,
"userid": user_data.userid,
"user_level": user_data.user_level,
"join_at": user_data.join_at,
"bonus": user_data.bonus,
"upload": user_data.upload,
"upload_gb": round(upload_gb, 2),
"download": user_data.download,
"download_gb": round(download_gb, 2),
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
"seeding": int(user_data.seeding) if user_data.seeding else 0,
"leeching": int(user_data.leeching)
if user_data.leeching
else 0,
"seeding_size": user_data.seeding_size,
"seeding_size_gb": round(seeding_size_gb, 2),
"leeching_size": user_data.leeching_size,
"leeching_size_gb": round(leeching_size_gb, 2),
"seeding_info_count": seeding_count,
"seeding_info": seeding_preview,
"seeding_info_truncated": seeding_truncated,
"message_unread": user_data.message_unread,
"message_unread_contents_count": unread_count,
"message_unread_contents": unread_preview,
"message_unread_contents_truncated": unread_truncated,
"err_msg": user_data.err_msg,
"updated_day": user_data.updated_day,
"updated_time": user_data.updated_time,
}
result["user_data"].append(user_data_dict)
for user_data in user_data_list:
# 格式化上传/下载量(转换为可读格式)
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
download_gb = (
user_data.download / (1024**3) if user_data.download else 0
)
seeding_size_gb = (
user_data.seeding_size / (1024**3)
if user_data.seeding_size
else 0
)
leeching_size_gb = (
user_data.leeching_size / (1024**3)
if user_data.leeching_size
else 0
)
# 如果有多条数据,只返回最新的(按更新时间排序)
if len(result["user_data"]) > 1:
result["user_data"].sort(
key=lambda x: (
x.get("updated_day", ""),
x.get("updated_time", ""),
),
reverse=True,
)
result["message"] = (
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
)
result["user_data"] = [result["user_data"][0]]
seeding_preview, seeding_count, seeding_truncated = _preview_list(
user_data.seeding_info
)
unread_preview, unread_count, unread_truncated = _preview_list(
user_data.message_unread_contents
)
user_data_dict = {
"domain": user_data.domain,
"name": user_data.name,
"username": user_data.username,
"userid": user_data.userid,
"user_level": user_data.user_level,
"join_at": user_data.join_at,
"bonus": user_data.bonus,
"upload": user_data.upload,
"upload_gb": round(upload_gb, 2),
"download": user_data.download,
"download_gb": round(download_gb, 2),
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
"seeding": int(user_data.seeding) if user_data.seeding else 0,
"leeching": int(user_data.leeching)
if user_data.leeching
else 0,
"seeding_size": user_data.seeding_size,
"seeding_size_gb": round(seeding_size_gb, 2),
"leeching_size": user_data.leeching_size,
"leeching_size_gb": round(leeching_size_gb, 2),
"seeding_info_count": seeding_count,
"seeding_info": seeding_preview,
"seeding_info_truncated": seeding_truncated,
"message_unread": user_data.message_unread,
"message_unread_contents_count": unread_count,
"message_unread_contents": unread_preview,
"message_unread_contents_truncated": unread_truncated,
"err_msg": user_data.err_msg,
"updated_day": user_data.updated_day,
"updated_time": user_data.updated_time,
}
result["user_data"].append(user_data_dict)
# 如果有多条数据,只返回最新的(按更新时间排序)
if len(result["user_data"]) > 1:
result["user_data"].sort(
key=lambda x: (
x.get("updated_day", ""),
x.get("updated_time", ""),
),
reverse=True,
)
result["message"] = (
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
)
result["user_data"] = [result["user_data"][0]]
return json.dumps(result, ensure_ascii=False, indent=2)
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"查询站点用户数据失败: {str(e)}"

View File

@@ -7,8 +7,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.db import AsyncSessionFactory
from app.db.models.subscribehistory import SubscribeHistory
from app.db.subscribehistory_oper import SubscribeHistoryOper
from app.log import logger
from app.schemas.types import media_type_to_agent
@@ -74,88 +73,87 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
if media_type not in ["all", "movie", "tv"]:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
# 获取数据库会话
async with AsyncSessionFactory() as db:
if name:
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
fetch_count = 500
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=fetch_count
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=fetch_count
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
all_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=fetch_count
)
)
# 按名称过滤
name_lower = name.lower()
filtered_history = [
record
for record in all_history
if record.name and name_lower in record.name.lower()
]
if not filtered_history:
return "未找到相关订阅历史记录"
# 名称过滤时直接返回所有匹配结果,不分页
simplified_records = self._simplify_records(filtered_history)
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
subscribe_history_oper = SubscribeHistoryOper()
if name:
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
fetch_count = 500
if media_type == "all":
movie_history = await subscribe_history_oper.async_list_by_type(
mtype="movie", page=1, count=fetch_count
)
return result_json
tv_history = await subscribe_history_oper.async_list_by_type(
mtype="tv", page=1, count=fetch_count
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
# 无名称过滤时,直接利用数据库分页
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=page * PAGE_SIZE
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=page * PAGE_SIZE
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
filtered_history = all_history
else:
filtered_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=page * PAGE_SIZE
)
all_history = list(
await subscribe_history_oper.async_list_by_type(
mtype=media_type, page=1, count=fetch_count
)
)
# 按名称过滤
name_lower = name.lower()
filtered_history = [
record
for record in all_history
if record.name and name_lower in record.name.lower()
]
if not filtered_history:
return "未找到相关订阅历史记录"
# 分页切片
total_count = len(filtered_history)
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_records = filtered_history[start:end]
if not page_records:
return f"{page} 页没有数据。"
simplified_records = self._simplify_records(page_records)
# 名称过滤时直接返回所有匹配结果,不分页
simplified_records = self._simplify_records(filtered_history)
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
)
has_more = total_count > end
payload_msg = f"{page} 页,当前页 {len(simplified_records)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
return result_json
else:
# 无名称过滤时,直接利用数据库分页
if media_type == "all":
movie_history = await subscribe_history_oper.async_list_by_type(
mtype="movie", page=1, count=page * PAGE_SIZE
)
tv_history = await subscribe_history_oper.async_list_by_type(
mtype="tv", page=1, count=page * PAGE_SIZE
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
filtered_history = all_history
else:
filtered_history = list(
await subscribe_history_oper.async_list_by_type(
mtype=media_type, page=1, count=page * PAGE_SIZE
)
)
return f"{payload_msg}\n\n{result_json}"
if not filtered_history:
return "未找到相关订阅历史记录"
# 分页切片
total_count = len(filtered_history)
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_records = filtered_history[start:end]
if not page_records:
return f"{page} 页没有数据。"
simplified_records = self._simplify_records(page_records)
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
)
has_more = total_count > end
payload_msg = f"{page} 页,当前页 {len(simplified_records)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
)
return f"{payload_msg}\n\n{result_json}"
except Exception as e:
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
return f"查询订阅历史时发生错误: {str(e)}"

View File

@@ -9,8 +9,11 @@ from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.agent.tools.impl._system_setting_utils import (
SettingSpec,
is_secret_setting_key,
list_setting_specs,
redact_secret_value,
resolve_setting_spec,
should_redact_setting,
)
from app.core.config import settings
from app.db.systemconfig_oper import SystemConfigOper
@@ -53,6 +56,13 @@ class QuerySystemSettingsInput(BaseModel):
"when multiple settings are matched it returns summaries only unless this is explicitly set to true."
),
)
show_secrets: Optional[bool] = Field(
False,
description=(
"Whether to return raw secret values such as API keys, tokens, cookies, and passwords. "
"Defaults to false; secret-like fields are redacted in returned values and previews."
),
)
class QuerySystemSettingsTool(MoviePilotTool):
@@ -85,15 +95,18 @@ class QuerySystemSettingsTool(MoviePilotTool):
@staticmethod
def _load_setting_value(spec: SettingSpec):
"""读取指定设置项的当前值。"""
if spec.source == "settings":
return getattr(settings, spec.key)
return SystemConfigOper().get(spec.key)
return SystemConfigOper().get(spec.systemconfig_key)
@staticmethod
def _summarize_value(value) -> dict:
def _summarize_value(value, *, redacted: bool = False) -> dict:
"""生成设置值摘要,避免列表和字典默认输出过长。"""
summary = {
"has_value": value is not None,
"value_type": type(value).__name__,
"redacted": redacted,
}
if isinstance(value, list):
summary["item_count"] = len(value)
@@ -122,14 +135,12 @@ class QuerySystemSettingsTool(MoviePilotTool):
group: Optional[str] = "all",
keyword: Optional[str] = None,
include_values: Optional[bool] = None,
show_secrets: Optional[bool] = False,
**kwargs,
) -> str:
logger.info(
"执行工具: %s, setting_key=%s, group=%s, keyword=%s",
self.name,
setting_key,
group,
keyword,
f"执行工具: {self.name}, setting_key={setting_key}, "
f"group={group}, keyword={keyword}"
)
try:
@@ -158,18 +169,30 @@ class QuerySystemSettingsTool(MoviePilotTool):
should_include_values = (
include_values if include_values is not None else len(specs) == 1
)
allow_secret_values = bool(show_secrets) and await self.is_admin_user()
settings_payload = []
for spec in specs:
value = self._load_setting_value(spec)
should_redact = (
should_redact_setting(spec, value) and not allow_secret_values
)
response_value = (
redact_secret_value(
value,
redact_scalar=is_secret_setting_key(spec.key),
)
if should_redact
else value
)
item = {
"setting_key": spec.key,
"source": spec.source,
"group": spec.group,
"label": spec.label,
}
item.update(self._summarize_value(value))
item.update(self._summarize_value(response_value, redacted=should_redact))
if should_include_values:
item["value"] = value
item["value"] = response_value
settings_payload.append(item)
return json.dumps(
@@ -177,6 +200,7 @@ class QuerySystemSettingsTool(MoviePilotTool):
"success": True,
"matched_count": len(settings_payload),
"include_values": should_include_values,
"show_secrets": allow_secret_values,
"settings": settings_payload,
},
ensure_ascii=False,

View File

@@ -7,8 +7,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.db import AsyncSessionFactory
from app.db.models.transferhistory import TransferHistory
from app.db.transferhistory_oper import TransferHistoryOper
from app.log import logger
from app.schemas.types import media_type_to_agent
from app.utils.jieba import cut as jieba_cut
@@ -70,70 +69,69 @@ class QueryTransferHistoryTool(MoviePilotTool):
# 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。
count = 30
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 处理标题搜索
if title:
# 使用统一分词封装处理标题,便于替换底层实现。
words = jieba_cut(title, HMM=False)
title_search = "%".join(words)
# 查询记录
result = await TransferHistory.async_list_by_title(
db, title=title_search, page=page, count=count, status=status_bool
)
total = await TransferHistory.async_count_by_title(
db, title=title_search, status=status_bool
)
else:
# 查询所有记录
result = await TransferHistory.async_list_by_page(
db, page=page, count=count, status=status_bool
)
total = await TransferHistory.async_count(db, status=status_bool)
transferhis = TransferHistoryOper()
# 处理标题搜索
if title:
# 使用统一分词封装处理标题,便于替换底层实现。
words = jieba_cut(title, HMM=False)
title_search = "%".join(words)
# 查询记录
result = await transferhis.async_list_by_title(
title=title_search, page=page, count=count, status=status_bool
)
total = await transferhis.async_count_by_title(
title=title_search, status=status_bool
)
else:
# 查询所有记录
result = await transferhis.async_list_by_page(
page=page, count=count, status=status_bool
)
total = await transferhis.async_count(status=status_bool)
if not result:
return "未找到相关整理历史记录"
if not result:
return "未找到相关整理历史记录"
# 转换为字典格式,只保留关键信息
simplified_records = []
for record in result:
simplified = {
"id": record.id,
"title": record.title,
"year": record.year,
"type": media_type_to_agent(record.type),
"category": record.category,
"seasons": record.seasons,
"episodes": record.episodes,
"src": record.src,
"dest": record.dest,
"mode": record.mode,
"status": "成功" if record.status else "失败",
"date": record.date,
"downloader": record.downloader,
"download_hash": record.download_hash
}
# 如果失败,添加错误信息
if not record.status and record.errmsg:
simplified["errmsg"] = record.errmsg
# 添加媒体ID信息如果有
if record.tmdbid:
simplified["tmdbid"] = record.tmdbid
if record.imdbid:
simplified["imdbid"] = record.imdbid
if record.doubanid:
simplified["doubanid"] = record.doubanid
simplified_records.append(simplified)
# 转换为字典格式,只保留关键信息
simplified_records = []
for record in result:
simplified = {
"id": record.id,
"title": record.title,
"year": record.year,
"type": media_type_to_agent(record.type),
"category": record.category,
"seasons": record.seasons,
"episodes": record.episodes,
"src": record.src,
"dest": record.dest,
"mode": record.mode,
"status": "成功" if record.status else "失败",
"date": record.date,
"downloader": record.downloader,
"download_hash": record.download_hash
}
# 如果失败,添加错误信息
if not record.status and record.errmsg:
simplified["errmsg"] = record.errmsg
# 添加媒体ID信息如果有
if record.tmdbid:
simplified["tmdbid"] = record.tmdbid
if record.imdbid:
simplified["imdbid"] = record.imdbid
if record.doubanid:
simplified["doubanid"] = record.doubanid
simplified_records.append(simplified)
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
# 计算总页数
total_pages = (total + count - 1) // count if total > 0 else 1
# 计算总页数
total_pages = (total + count - 1) // count if total > 0 else 1
# 构建分页信息
pagination_info = f"{page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
# 构建分页信息
pagination_info = f"{page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
return f"{pagination_info}\n\n{result_json}"
return f"{pagination_info}\n\n{result_json}"
except Exception as e:
logger.error(f"查询整理历史记录失败: {e}", exc_info=True)
return f"查询整理历史记录时发生错误: {str(e)}"

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.db import AsyncSessionFactory
from app.db.workflow_oper import WorkflowOper
from app.log import logger
@@ -56,75 +55,73 @@ class QueryWorkflowsTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: state={state}, name={name}, trigger_type={trigger_type}")
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
workflow_oper = WorkflowOper(db)
workflows = await workflow_oper.async_list()
# 过滤工作流
filtered_workflows = []
for wf in workflows:
# 按状态过滤
if state != "all" and wf.state != state:
workflow_oper = WorkflowOper()
workflows = await workflow_oper.async_list()
# 过滤工作流
filtered_workflows = []
for wf in workflows:
# 按状态过滤
if state != "all" and wf.state != state:
continue
# 按触发类型过滤
if trigger_type != "all":
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
continue
# 按触发类型过滤
if trigger_type != "all":
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
continue
elif trigger_type == "event" and wf.trigger_type != "event":
continue
elif trigger_type == "manual" and wf.trigger_type != "manual":
continue
# 按名称过滤(部分匹配)
if name and wf.name and name.lower() not in wf.name.lower():
elif trigger_type == "event" and wf.trigger_type != "event":
continue
filtered_workflows.append(wf)
if not filtered_workflows:
return "未找到相关工作流"
# 转换为字典格式,只保留关键信息
simplified_workflows = []
for wf in filtered_workflows:
# 状态说明
state_map = {
"W": "等待",
"R": "运行中",
"P": "暂停",
"S": "成功",
"F": "失败"
}
state_desc = state_map.get(wf.state, wf.state)
# 触发类型说明
trigger_type_map = {
"timer": "定时触发",
"event": "事件触发",
"manual": "手动触发"
}
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
simplified = {
"id": wf.id,
"name": wf.name,
"description": wf.description,
"trigger_type": trigger_type_desc,
"state": state_desc,
"run_count": wf.run_count,
"timer": wf.timer,
"event_type": wf.event_type,
"add_time": wf.add_time,
"last_time": wf.last_time,
"current_action": wf.current_action
}
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
simplified_workflows.append(simplified)
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
return result_json
elif trigger_type == "manual" and wf.trigger_type != "manual":
continue
# 按名称过滤(部分匹配)
if name and wf.name and name.lower() not in wf.name.lower():
continue
filtered_workflows.append(wf)
if not filtered_workflows:
return "未找到相关工作流"
# 转换为字典格式,只保留关键信息
simplified_workflows = []
for wf in filtered_workflows:
# 状态说明
state_map = {
"W": "等待",
"R": "运行中",
"P": "暂停",
"S": "成功",
"F": "失败"
}
state_desc = state_map.get(wf.state, wf.state)
# 触发类型说明
trigger_type_map = {
"timer": "定时触发",
"event": "事件触发",
"manual": "手动触发"
}
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
simplified = {
"id": wf.id,
"name": wf.name,
"description": wf.description,
"trigger_type": trigger_type_desc,
"state": state_desc,
"run_count": wf.run_count,
"timer": wf.timer,
"event_type": wf.event_type,
"add_time": wf.add_time,
"last_time": wf.last_time,
"current_action": wf.current_action
}
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
simplified_workflows.append(simplified)
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
return result_json
except Exception as e:
logger.error(f"查询工作流失败: {e}", exc_info=True)
return f"查询工作流时发生错误: {str(e)}"

View File

@@ -15,7 +15,7 @@ MAX_READ_SIZE = 50 * 1024
class ReadFileInput(BaseModel):
"""Input parameters for read file tool"""
"""文件读取工具的输入参数模型。"""
file_path: str = Field(..., description="The absolute path of the file to read")
start_line: Optional[int] = Field(None, description="The starting line number (1-based, inclusive). If not provided, reading starts from the beginning of the file.")
end_line: Optional[int] = Field(None, description="The ending line number (1-based, inclusive). If not provided, reading goes until the end of the file.")

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.chain.workflow import WorkflowChain
from app.db import AsyncSessionFactory
from app.db.workflow_oper import WorkflowOper
from app.log import logger
@@ -65,26 +64,23 @@ class RunWorkflowTool(MoviePilotTool):
)
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
workflow_oper = WorkflowOper(db)
workflow = await workflow_oper.async_get(workflow_id)
workflow = await WorkflowOper().async_get(workflow_id)
if not workflow:
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
if not workflow:
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
# 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。
state, errmsg = await self.run_blocking(
"workflow",
self._run_workflow_sync,
workflow.id,
from_begin,
)
# 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。
state, errmsg = await self.run_blocking(
"workflow",
self._run_workflow_sync,
workflow.id,
from_begin,
)
if not state:
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
else:
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
if not state:
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
else:
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
except Exception as e:
logger.error(f"执行工作流失败: {e}", exc_info=True)
return f"执行工作流时发生错误: {str(e)}"

View File

@@ -8,8 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.core.event import eventmanager
from app.db import AsyncSessionFactory
from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.log import logger
from app.schemas.types import EventType
from app.utils.string import StringUtils
@@ -127,108 +126,106 @@ class UpdateSiteTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取站点
site = await Site.async_get(db, site_id)
if not site:
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
# 构建更新字典
site_dict = {}
# 基本信息
if name is not None:
site_dict["name"] = name
# URL处理需要校正格式
if url is not None:
_scheme, _netloc = StringUtils.get_url_netloc(url)
site_dict["url"] = f"{_scheme}://{_netloc}/"
if pri is not None:
site_dict["pri"] = pri
if rss is not None:
site_dict["rss"] = rss
# 认证信息
if cookie is not None:
site_dict["cookie"] = cookie
if ua is not None:
site_dict["ua"] = ua
if apikey is not None:
site_dict["apikey"] = apikey
if token is not None:
site_dict["token"] = token
# 配置选项
if proxy is not None:
site_dict["proxy"] = proxy
if filter is not None:
site_dict["filter"] = filter
if note is not None:
site_dict["note"] = note
if timeout is not None:
site_dict["timeout"] = timeout
# 流控设置
if limit_interval is not None:
site_dict["limit_interval"] = limit_interval
if limit_count is not None:
site_dict["limit_count"] = limit_count
if limit_seconds is not None:
site_dict["limit_seconds"] = limit_seconds
# 状态和下载器
if is_active is not None:
site_dict["is_active"] = is_active
if downloader is not None:
site_dict["downloader"] = downloader
# 如果没有要更新的字段
if not site_dict:
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新站点
await site.async_update(db, site_dict)
# 重新获取更新后的站点数据
updated_site = await Site.async_get(db, site_id)
# 发送站点更新事件
await eventmanager.async_send_event(
EventType.SiteUpdated,
{"domain": updated_site.domain if updated_site else site.domain},
site_oper = SiteOper()
site = await site_oper.async_get(site_id)
if not site:
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
# 构建返回结果
result = {
"success": True,
"message": f"站点 #{site_id} 更新成功",
"site_id": site_id,
"updated_fields": list(site_dict.keys()),
# 构建更新字典
site_dict = {}
# 基本信息
if name is not None:
site_dict["name"] = name
# URL处理需要校正格式
if url is not None:
_scheme, _netloc = StringUtils.get_url_netloc(url)
site_dict["url"] = f"{_scheme}://{_netloc}/"
if pri is not None:
site_dict["pri"] = pri
if rss is not None:
site_dict["rss"] = rss
# 认证信息
if cookie is not None:
site_dict["cookie"] = cookie
if ua is not None:
site_dict["ua"] = ua
if apikey is not None:
site_dict["apikey"] = apikey
if token is not None:
site_dict["token"] = token
# 配置选项
if proxy is not None:
site_dict["proxy"] = proxy
if filter is not None:
site_dict["filter"] = filter
if note is not None:
site_dict["note"] = note
if timeout is not None:
site_dict["timeout"] = timeout
# 流控设置
if limit_interval is not None:
site_dict["limit_interval"] = limit_interval
if limit_count is not None:
site_dict["limit_count"] = limit_count
if limit_seconds is not None:
site_dict["limit_seconds"] = limit_seconds
# 状态和下载器
if is_active is not None:
site_dict["is_active"] = is_active
if downloader is not None:
site_dict["downloader"] = downloader
# 如果没有要更新的字段
if not site_dict:
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新站点
await site_oper.async_update(site_id, site_dict)
# 重新获取更新后的站点数据
updated_site = await site_oper.async_get(site_id)
# 发送站点更新事件
await eventmanager.async_send_event(
EventType.SiteUpdated,
{"domain": updated_site.domain if updated_site else site.domain},
)
# 构建返回结果
result = {
"success": True,
"message": f"站点 #{site_id} 更新成功",
"site_id": site_id,
"updated_fields": list(site_dict.keys()),
}
if updated_site:
result["site"] = {
"id": updated_site.id,
"name": updated_site.name,
"domain": updated_site.domain,
"url": updated_site.url,
"pri": updated_site.pri,
"is_active": updated_site.is_active,
"downloader": updated_site.downloader,
"proxy": updated_site.proxy,
"timeout": updated_site.timeout,
}
if updated_site:
result["site"] = {
"id": updated_site.id,
"name": updated_site.name,
"domain": updated_site.domain,
"url": updated_site.url,
"pri": updated_site.pri,
"is_active": updated_site.is_active,
"downloader": updated_site.downloader,
"proxy": updated_site.proxy,
"timeout": updated_site.timeout,
}
return json.dumps(result, ensure_ascii=False, indent=2)
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"更新站点失败: {str(e)}"

View File

@@ -8,8 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.core.event import eventmanager
from app.db import AsyncSessionFactory
from app.db.models.subscribe import Subscribe
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
from app.schemas.types import EventType
@@ -157,149 +156,147 @@ class UpdateSubscribeTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取订阅
subscribe = await Subscribe.async_get(db, subscribe_id)
if not subscribe:
return json.dumps(
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
ensure_ascii=False,
)
# 保存旧数据用于事件
old_subscribe_dict = subscribe.to_dict()
# 构建更新字典
subscribe_dict = {}
# 基本信息
if name is not None:
subscribe_dict["name"] = name
if year is not None:
subscribe_dict["year"] = year
if season is not None:
subscribe_dict["season"] = season
# 集数相关
if total_episode is not None:
subscribe_dict["total_episode"] = total_episode
# 如果总集数增加,缺失集数也要相应增加
if total_episode > (subscribe.total_episode or 0):
old_lack = subscribe.lack_episode or 0
subscribe_dict["lack_episode"] = old_lack + (
total_episode - (subscribe.total_episode or 0)
)
# 标记为手动修改过总集数
subscribe_dict["manual_total_episode"] = 1
# 缺失集数处理(只有在没有提供总集数时才单独处理)
# 注意:如果 lack_episode 为 0不更新避免更新为0
if lack_episode is not None and total_episode is None:
if lack_episode > 0:
subscribe_dict["lack_episode"] = lack_episode
# 如果 lack_episode 为 0不添加到更新字典中保持原值或由总集数逻辑处理
if start_episode is not None:
subscribe_dict["start_episode"] = start_episode
# 过滤规则
if quality is not None:
subscribe_dict["quality"] = quality
if resolution is not None:
subscribe_dict["resolution"] = resolution
if effect is not None:
subscribe_dict["effect"] = effect
if include is not None:
subscribe_dict["include"] = include
if exclude is not None:
subscribe_dict["exclude"] = exclude
if filter is not None:
subscribe_dict["filter"] = filter
# 状态
if state is not None:
valid_states = ["R", "P", "S", "N"]
if state not in valid_states:
return json.dumps(
{
"success": False,
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
},
ensure_ascii=False,
)
subscribe_dict["state"] = state
# 下载配置
if sites is not None:
subscribe_dict["sites"] = sites
if downloader is not None:
subscribe_dict["downloader"] = downloader
if save_path is not None:
subscribe_dict["save_path"] = save_path
if best_version is not None:
subscribe_dict["best_version"] = best_version
if best_version_full is not None:
subscribe_dict["best_version_full"] = best_version_full
# 其他配置
if custom_words is not None:
subscribe_dict["custom_words"] = custom_words
if media_category is not None:
subscribe_dict["media_category"] = media_category
if episode_group is not None:
subscribe_dict["episode_group"] = episode_group
# 如果没有要更新的字段
if not subscribe_dict:
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新订阅
await subscribe.async_update(db, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subscribe_id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict()
if updated_subscribe
else {},
},
subscribe_oper = SubscribeOper()
subscribe = await subscribe_oper.async_get(subscribe_id)
if not subscribe:
return json.dumps(
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
ensure_ascii=False,
)
# 构建返回结果
result = {
"success": True,
"message": f"订阅 #{subscribe_id} 更新成功",
# 保存旧数据用于事件
old_subscribe_dict = subscribe.to_dict()
# 构建更新字典
subscribe_dict = {}
# 基本信息
if name is not None:
subscribe_dict["name"] = name
if year is not None:
subscribe_dict["year"] = year
if season is not None:
subscribe_dict["season"] = season
# 集数相关
if total_episode is not None:
subscribe_dict["total_episode"] = total_episode
# 如果总集数增加,缺失集数也要相应增加
if total_episode > (subscribe.total_episode or 0):
old_lack = subscribe.lack_episode or 0
subscribe_dict["lack_episode"] = old_lack + (
total_episode - (subscribe.total_episode or 0)
)
# 标记为手动修改过总集数
subscribe_dict["manual_total_episode"] = 1
# 缺失集数处理(只有在没有提供总集数时才单独处理)
# 注意:如果 lack_episode 为 0不更新避免更新为0
if lack_episode is not None and total_episode is None:
if lack_episode > 0:
subscribe_dict["lack_episode"] = lack_episode
# 如果 lack_episode 为 0不添加到更新字典中保持原值或由总集数逻辑处理
if start_episode is not None:
subscribe_dict["start_episode"] = start_episode
# 过滤规则
if quality is not None:
subscribe_dict["quality"] = quality
if resolution is not None:
subscribe_dict["resolution"] = resolution
if effect is not None:
subscribe_dict["effect"] = effect
if include is not None:
subscribe_dict["include"] = include
if exclude is not None:
subscribe_dict["exclude"] = exclude
if filter is not None:
subscribe_dict["filter"] = filter
# 状态
if state is not None:
valid_states = ["R", "P", "S", "N"]
if state not in valid_states:
return json.dumps(
{
"success": False,
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
},
ensure_ascii=False,
)
subscribe_dict["state"] = state
# 下载配置
if sites is not None:
subscribe_dict["sites"] = sites
if downloader is not None:
subscribe_dict["downloader"] = downloader
if save_path is not None:
subscribe_dict["save_path"] = save_path
if best_version is not None:
subscribe_dict["best_version"] = best_version
if best_version_full is not None:
subscribe_dict["best_version_full"] = best_version_full
# 其他配置
if custom_words is not None:
subscribe_dict["custom_words"] = custom_words
if media_category is not None:
subscribe_dict["media_category"] = media_category
if episode_group is not None:
subscribe_dict["episode_group"] = episode_group
# 如果没有要更新的字段
if not subscribe_dict:
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新订阅
await subscribe_oper.async_update(subscribe_id, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await subscribe_oper.async_get(subscribe_id)
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subscribe_id,
"updated_fields": list(subscribe_dict.keys()),
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict()
if updated_subscribe
else {},
},
)
# 构建返回结果
result = {
"success": True,
"message": f"订阅 #{subscribe_id} 更新成功",
"subscribe_id": subscribe_id,
"updated_fields": list(subscribe_dict.keys()),
}
if updated_subscribe:
result["subscribe"] = {
"id": updated_subscribe.id,
"name": updated_subscribe.name,
"year": updated_subscribe.year,
"type": updated_subscribe.type,
"season": updated_subscribe.season,
"state": updated_subscribe.state,
"total_episode": updated_subscribe.total_episode,
"lack_episode": updated_subscribe.lack_episode,
"start_episode": updated_subscribe.start_episode,
"quality": updated_subscribe.quality,
"resolution": updated_subscribe.resolution,
"effect": updated_subscribe.effect,
}
if updated_subscribe:
result["subscribe"] = {
"id": updated_subscribe.id,
"name": updated_subscribe.name,
"year": updated_subscribe.year,
"type": updated_subscribe.type,
"season": updated_subscribe.season,
"state": updated_subscribe.state,
"total_episode": updated_subscribe.total_episode,
"lack_episode": updated_subscribe.lack_episode,
"start_episode": updated_subscribe.start_episode,
"quality": updated_subscribe.quality,
"resolution": updated_subscribe.resolution,
"effect": updated_subscribe.effect,
}
return json.dumps(result, ensure_ascii=False, indent=2)
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"更新订阅失败: {str(e)}"

View File

@@ -11,7 +11,10 @@ from app.agent.tools.tags import ToolTag
from app.agent.tools.impl._system_setting_utils import (
SettingSpec,
get_default_list_match_field,
is_secret_setting_key,
redact_secret_value,
resolve_setting_spec,
should_redact_setting,
)
from app.core.config import settings
from app.core.event import eventmanager
@@ -102,12 +105,14 @@ class UpdateSystemSettingsTool(MoviePilotTool):
@staticmethod
def _load_setting_value(spec: SettingSpec):
"""读取指定设置项的当前值。"""
if spec.source == "settings":
return getattr(settings, spec.key)
return SystemConfigOper().get(spec.key)
return SystemConfigOper().get(spec.systemconfig_key)
@staticmethod
def _normalize_systemconfig_value(value: Any):
"""规范化写入 SystemConfig 的空列表值。"""
if isinstance(value, list):
filtered = [item for item in value if item is not None]
return filtered or None
@@ -221,10 +226,7 @@ class UpdateSystemSettingsTool(MoviePilotTool):
**kwargs,
) -> str:
logger.info(
"执行工具: %s, setting_key=%s, operation=%s",
self.name,
setting_key,
operation,
f"执行工具: {self.name}, setting_key={setting_key}, operation={operation}"
)
try:
@@ -266,7 +268,10 @@ class UpdateSystemSettingsTool(MoviePilotTool):
else:
normalized_value = self._normalize_systemconfig_value(next_value)
event_value = normalized_value
success = await SystemConfigOper().async_set(spec.key, normalized_value)
success = await SystemConfigOper().async_set(
spec.systemconfig_key,
normalized_value,
)
changed = success is True
if changed:
@@ -280,6 +285,26 @@ class UpdateSystemSettingsTool(MoviePilotTool):
)
saved_value = self._load_setting_value(spec)
redact_values = (
should_redact_setting(spec, saved_value)
or should_redact_setting(spec, current_value)
)
response_previous_value = (
redact_secret_value(
current_value,
redact_scalar=is_secret_setting_key(spec.key),
)
if redact_values
else current_value
)
response_saved_value = (
redact_secret_value(
saved_value,
redact_scalar=is_secret_setting_key(spec.key),
)
if redact_values
else saved_value
)
if not changed and not message:
message = "配置值未发生变化"
@@ -295,8 +320,9 @@ class UpdateSystemSettingsTool(MoviePilotTool):
"group": spec.group,
"label": spec.label,
},
"previous_value": current_value,
"saved_value": saved_value,
"values_redacted": redact_values,
"previous_value": response_previous_value,
"saved_value": response_saved_value,
},
ensure_ascii=False,
indent=2,

View File

@@ -12,7 +12,7 @@ from app.log import logger
class WriteFileInput(BaseModel):
"""Input parameters for write file tool"""
"""文件写入工具的输入参数模型。"""
file_path: str = Field(..., description="The absolute path of the file to write")
content: str = Field(..., description="The content to write into the file")
@@ -26,7 +26,7 @@ class WriteFileTool(MoviePilotTool):
]
description: str = (
"Write full content to a local text file. Non-admin users can only write "
"inside the MoviePilot config, Agent memory/activity, and log directories."
"inside the MoviePilot Agent config and log directories."
)
args_schema: Type[BaseModel] = WriteFileInput

View File

@@ -114,6 +114,12 @@ class DownloadHistoryOper(DbOper):
"""
return DownloadHistory.list_by_page(self._db, page, count)
async def async_delete_history(self, historyid: int):
"""
异步删除下载记录。
"""
await DownloadHistory.async_delete(self._db, historyid)
def truncate(self):
"""
清空下载记录

View File

@@ -79,6 +79,15 @@ class SiteOper(DbOper):
site.update(self._db, payload)
return site
async def async_update(self, sid: int, payload: dict) -> Site:
"""
异步更新站点。
"""
site = await self.async_get(sid)
if site:
await site.async_update(self._db, payload)
return site
def get_by_domain(self, domain: str) -> Site:
"""
按域名获取站点
@@ -170,6 +179,16 @@ class SiteOper(DbOper):
"""
return SiteUserData.get_by_domain(self._db, domain=domain, workdate=workdate)
async def async_get_userdata_by_domain(
self, domain: str, workdate: Optional[str] = None
) -> List[SiteUserData]:
"""
异步获取站点用户数据。
"""
return await SiteUserData.async_get_by_domain(
self._db, domain=domain, workdate=workdate
)
def get_userdata_by_date(self, date: str) -> List[SiteUserData]:
"""
获取站点用户数据

View File

@@ -169,6 +169,22 @@ class SubscribeOper(DbOper):
"""
await Subscribe.async_delete(self._db, rid=sid)
async def async_update(self, sid: int, payload: dict) -> Subscribe:
"""
异步更新订阅。
"""
subscribe = await self.async_get(sid)
if subscribe:
payload = _normalize_integer_flags(payload)
await subscribe.async_update(self._db, payload)
return subscribe
async def async_update_filter_groups(self, sid: int, filter_groups: list) -> Subscribe:
"""
异步更新订阅使用的过滤规则组。
"""
return await self.async_update(sid, {"filter_groups": filter_groups})
def update(self, sid: int, payload: dict) -> Subscribe:
"""
更新订阅

View File

@@ -0,0 +1,26 @@
from typing import List, Optional
from app.db import DbOper
from app.db.models.subscribehistory import SubscribeHistory
class SubscribeHistoryOper(DbOper):
"""
订阅历史管理。
"""
async def async_list_by_type(
self,
mtype: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
) -> List[SubscribeHistory]:
"""
异步按媒体类型分页查询订阅历史。
"""
return await SubscribeHistory.async_list_by_type(
self._db,
mtype=mtype,
page=page,
count=count,
)

View File

@@ -26,6 +26,51 @@ class TransferHistoryOper(DbOper):
"""
return await TransferHistory.async_get(self._db, historyid)
async def async_list_by_title(
self,
title: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
) -> List[TransferHistory]:
"""
异步按标题分页查询转移记录。
"""
return await TransferHistory.async_list_by_title(
self._db, title=title, page=page, count=count, status=status
)
async def async_list_by_page(
self,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
) -> List[TransferHistory]:
"""
异步分页查询转移记录。
"""
return await TransferHistory.async_list_by_page(
self._db, page=page, count=count, status=status
)
async def async_count(self, status: Optional[bool] = None) -> int:
"""
异步统计转移记录数量。
"""
return await TransferHistory.async_count(self._db, status=status)
async def async_count_by_title(
self,
title: str,
status: Optional[bool] = None,
) -> int:
"""
异步按标题统计转移记录数量。
"""
return await TransferHistory.async_count_by_title(
self._db, title=title, status=status
)
def get_by_title(self, title: str) -> List[TransferHistory]:
"""
按标题查询转移记录