mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-02 13:21:35 +08:00
Refactor movie pilot config and test coverage
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
@@ -148,6 +149,16 @@ class _SessionUsageSnapshot:
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompiledAgentBundle:
|
||||
"""会话内可复用的 Agent 图及其构造签名。"""
|
||||
|
||||
signature: tuple[Any, ...]
|
||||
agent: Any
|
||||
streaming: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class _ThinkTagStripper:
|
||||
"""
|
||||
流式剥离 <think>...</think> 标签的辅助类。
|
||||
@@ -280,6 +291,8 @@ class MoviePilotAgent:
|
||||
self._llm_runtime_config: Optional[Dict[str, Any]] = None
|
||||
self._llm_provider_selection: Dict[str, Any] = {}
|
||||
self._agent_started_at: Optional[datetime] = None
|
||||
self._compiled_agent_bundle: Optional[_CompiledAgentBundle] = None
|
||||
self._last_agent_cache_hit = False
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -980,6 +993,90 @@ class MoviePilotAgent:
|
||||
allow_message_tools=self.allow_message_tools,
|
||||
)
|
||||
|
||||
def _refresh_tool_context(self, values: Dict[str, object]) -> None:
|
||||
"""
|
||||
刷新本轮工具共享上下文。
|
||||
|
||||
工具对象可能随会话内 Agent 图缓存被复用,因此这里保留 dict 对象本身,
|
||||
只替换其中内容,确保缓存工具看到的是最新权限与回复状态。
|
||||
"""
|
||||
self._tool_context.clear()
|
||||
self._tool_context.update(values)
|
||||
|
||||
@staticmethod
|
||||
def _public_runtime_config_signature(runtime_config: Dict[str, Any]) -> tuple:
|
||||
"""生成不包含密钥明文的 LLM 运行时签名。"""
|
||||
api_key = runtime_config.get("api_key") or ""
|
||||
api_key_digest = (
|
||||
hashlib.sha256(str(api_key).encode("utf-8")).hexdigest()[:12]
|
||||
if api_key
|
||||
else ""
|
||||
)
|
||||
return (
|
||||
runtime_config.get("provider"),
|
||||
runtime_config.get("model"),
|
||||
api_key_digest,
|
||||
runtime_config.get("base_url"),
|
||||
runtime_config.get("base_url_preset"),
|
||||
runtime_config.get("user_agent"),
|
||||
bool(runtime_config.get("use_proxy")),
|
||||
runtime_config.get("thinking_level"),
|
||||
)
|
||||
|
||||
async def _agent_bundle_signature(self, streaming: bool) -> tuple[Any, ...]:
|
||||
"""构造会话内 Agent 图缓存签名。"""
|
||||
runtime_config = await self._resolve_llm_runtime_config()
|
||||
return (
|
||||
streaming,
|
||||
self.channel,
|
||||
self.source,
|
||||
self.user_id,
|
||||
self.username,
|
||||
self.allow_message_tools,
|
||||
bool(self._tool_context.get("is_admin")),
|
||||
self.has_message_context,
|
||||
self.is_background,
|
||||
settings.AI_AGENT_VERBOSE,
|
||||
settings.LLM_MAX_TOOLS,
|
||||
settings.LLM_MAX_ITERATIONS,
|
||||
self._public_runtime_config_signature(runtime_config),
|
||||
agent_runtime_manager.current_signature(),
|
||||
)
|
||||
|
||||
def _get_cached_agent(
|
||||
self, signature: tuple[Any, ...], streaming: bool
|
||||
) -> Optional[Any]:
|
||||
"""按签名读取当前会话已编译的 Agent 图。"""
|
||||
bundle = self._compiled_agent_bundle
|
||||
if (
|
||||
bundle
|
||||
and bundle.streaming == streaming
|
||||
and bundle.signature == signature
|
||||
):
|
||||
return bundle.agent
|
||||
return None
|
||||
|
||||
def _cache_agent(
|
||||
self,
|
||||
*,
|
||||
signature: tuple[Any, ...],
|
||||
agent: Any,
|
||||
streaming: bool,
|
||||
) -> Any:
|
||||
"""保存当前会话可复用的 Agent 图。"""
|
||||
self._compiled_agent_bundle = _CompiledAgentBundle(
|
||||
signature=signature,
|
||||
agent=agent,
|
||||
streaming=streaming,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def _latest_turn_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||
"""从完整历史中提取本轮新增用户消息。"""
|
||||
return [messages[-1]] if messages else []
|
||||
|
||||
def _initialize_subagent_tools(self) -> List:
|
||||
"""
|
||||
初始化子代理专用静默工具列表。
|
||||
@@ -1006,6 +1103,13 @@ class MoviePilotAgent:
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
try:
|
||||
bundle_signature = await self._agent_bundle_signature(streaming)
|
||||
cached_agent = self._get_cached_agent(bundle_signature, streaming)
|
||||
self._last_agent_cache_hit = bool(cached_agent)
|
||||
if cached_agent:
|
||||
logger.debug(f"复用会话内 Agent 图: session_id={self.session_id}")
|
||||
return cached_agent
|
||||
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
|
||||
@@ -1113,13 +1217,18 @@ class MoviePilotAgent:
|
||||
)
|
||||
)
|
||||
|
||||
return create_agent(
|
||||
agent = create_agent(
|
||||
model=agent_model,
|
||||
tools=[*tools, *skill_tools, *activity_log_tools],
|
||||
system_prompt=system_prompt,
|
||||
middleware=middlewares,
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
return self._cache_agent(
|
||||
signature=bundle_signature,
|
||||
agent=agent,
|
||||
streaming=streaming,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
@@ -1137,12 +1246,15 @@ class MoviePilotAgent:
|
||||
user_display_saved = False
|
||||
try:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
f"Agent推理: session_id={self.session_id}, "
|
||||
f"input_chars={len(message or '')}, "
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}, "
|
||||
f"audio_input={has_audio_input}"
|
||||
)
|
||||
self._tool_context = await self._build_tool_context(
|
||||
should_dispatch_reply=self.should_dispatch_reply
|
||||
self._refresh_tool_context(
|
||||
await self._build_tool_context(
|
||||
should_dispatch_reply=self.should_dispatch_reply
|
||||
)
|
||||
)
|
||||
self._streamed_output = ""
|
||||
|
||||
@@ -1330,6 +1442,11 @@ class MoviePilotAgent:
|
||||
|
||||
# 创建智能体(根据是否流式传入不同 LLM)
|
||||
agent = await self._create_agent(streaming=use_streaming)
|
||||
input_messages = (
|
||||
self._latest_turn_messages(messages)
|
||||
if self._last_agent_cache_hit
|
||||
else messages
|
||||
)
|
||||
|
||||
if use_streaming:
|
||||
self.stream_handler.set_dispatch_policy(
|
||||
@@ -1348,7 +1465,7 @@ class MoviePilotAgent:
|
||||
# 流式运行智能体,token 直接推送到 stream_handler
|
||||
await self._stream_agent_tokens(
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
messages={"messages": input_messages},
|
||||
config=agent_config,
|
||||
on_token=self._handle_stream_text,
|
||||
)
|
||||
@@ -1387,7 +1504,7 @@ class MoviePilotAgent:
|
||||
else:
|
||||
# 非流式模式:后台任务或渠道不支持消息编辑
|
||||
await agent.ainvoke(
|
||||
{"messages": messages},
|
||||
{"messages": input_messages},
|
||||
config=agent_config,
|
||||
)
|
||||
|
||||
@@ -1446,9 +1563,11 @@ class MoviePilotAgent:
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
self._compiled_agent_bundle = None
|
||||
execution_error = "任务已取消"
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
self._compiled_agent_bundle = None
|
||||
execution_error = str(e)
|
||||
if self._messages_have_image_input(messages) and self._is_unsupported_image_input_error(e):
|
||||
logger.warning(
|
||||
@@ -1492,6 +1611,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
清理智能体资源
|
||||
"""
|
||||
self._compiled_agent_bundle = None
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
|
||||
@@ -1628,6 +1628,37 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
return None
|
||||
|
||||
def _resolve_cached_model_record(
|
||||
self,
|
||||
provider_id: str,
|
||||
model_id: Optional[str],
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
transport: str = "openai",
|
||||
) -> dict[str, Any] | None:
|
||||
"""从缓存中的模型元数据构造轻量模型记录,不触发远端模型列表刷新。"""
|
||||
if not model_id:
|
||||
return None
|
||||
metadata = self.resolve_cached_model_metadata(
|
||||
provider_id,
|
||||
model_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
) or {}
|
||||
if not metadata:
|
||||
return self._normalize_model_record(
|
||||
model_id=model_id,
|
||||
transport=transport,
|
||||
source="configured",
|
||||
)
|
||||
return self._normalize_model_record(
|
||||
model_id=model_id,
|
||||
display_name=metadata.get("name") or model_id,
|
||||
metadata=metadata,
|
||||
transport=transport,
|
||||
source="models.dev-cache",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_record(
|
||||
model_id: str,
|
||||
@@ -2104,7 +2135,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
try:
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
except Exception as err:
|
||||
print(err)
|
||||
logger.debug(f"解析 JWT token 内容失败: {err}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
@@ -2587,40 +2618,29 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
normalized_api_key = str(api_key or "").strip() or None
|
||||
normalized_base_url = self._sanitize_base_url(base_url)
|
||||
model_record = None
|
||||
if model:
|
||||
try:
|
||||
model_record = next(
|
||||
(
|
||||
item
|
||||
for item in await self.list_models(
|
||||
normalized_provider_id,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
user_agent=user_agent,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
if item["id"] == model
|
||||
),
|
||||
None,
|
||||
)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
model_record = None
|
||||
default_transport = (
|
||||
"anthropic" if resolved_runtime == "anthropic_compatible" else "openai"
|
||||
)
|
||||
model_record = self._resolve_cached_model_record(
|
||||
normalized_provider_id,
|
||||
model,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
transport=default_transport,
|
||||
)
|
||||
model_metadata = self.resolve_cached_model_metadata(
|
||||
normalized_provider_id,
|
||||
model,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
)
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"provider_id": normalized_provider_id,
|
||||
"runtime": resolved_runtime,
|
||||
"model_id": model,
|
||||
"model_record": model_record,
|
||||
"model_metadata": await self.resolve_model_metadata(
|
||||
normalized_provider_id,
|
||||
model,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=normalized_base_url_preset_id,
|
||||
use_proxy=use_proxy,
|
||||
),
|
||||
"model_metadata": model_metadata,
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"auth_mode": "api_key",
|
||||
@@ -2631,8 +2651,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
try:
|
||||
auth = await self._resolve_chatgpt_oauth()
|
||||
except Exception as err:
|
||||
print(err)
|
||||
pass
|
||||
logger.debug(f"解析 ChatGPT OAuth 鉴权失败,回退 API Key 模式: {err}")
|
||||
|
||||
if auth:
|
||||
headers = {"originator": "moviepilot"}
|
||||
|
||||
@@ -7,12 +7,14 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, NotRequired, Optional, TypedDict
|
||||
|
||||
import anyio
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
@@ -579,14 +581,29 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
|
||||
entry = f"- **{now_str}** {summary}\n"
|
||||
try:
|
||||
if await log_path.exists():
|
||||
existing = await log_path.read_text(encoding="utf-8", errors="replace")
|
||||
await log_path.write_text(existing + entry, encoding="utf-8")
|
||||
async with await anyio.open_file(
|
||||
log_path,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
) as stream:
|
||||
await stream.write(entry)
|
||||
else:
|
||||
header = f"# {today_str} 活动日志\n\n"
|
||||
await log_path.write_text(header + entry, encoding="utf-8")
|
||||
logger.debug("Activity logged: %s", summary[:80])
|
||||
try:
|
||||
fd = os.open(log_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
|
||||
except FileExistsError:
|
||||
async with await anyio.open_file(
|
||||
log_path,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
) as stream:
|
||||
await stream.write(entry)
|
||||
else:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as stream:
|
||||
stream.write(header + entry)
|
||||
logger.debug(f"Activity logged: {summary[:80]}")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to append activity log: %s", e)
|
||||
logger.warning(f"Failed to append activity log: {e}")
|
||||
|
||||
async def _cleanup_old_logs(self) -> None:
|
||||
"""清理超过保留天数的旧日志文件。"""
|
||||
@@ -608,20 +625,16 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
|
||||
file_date = datetime.strptime(match.group(1), "%Y-%m-%d").date()
|
||||
if file_date < cutoff_date:
|
||||
await path.unlink()
|
||||
logger.debug("Cleaned up old activity log: %s", path.name)
|
||||
logger.debug(f"Cleaned up old activity log: {path.name}")
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup old activity logs: %s", e)
|
||||
logger.warning(f"Failed to cleanup old activity logs: {e}")
|
||||
|
||||
async def abefore_agent(
|
||||
self, state: ActivityLogState, runtime: Runtime
|
||||
) -> Optional[ActivityLogStateUpdate]:
|
||||
"""在 Agent 执行前加载近期活动日志。"""
|
||||
# 如果已经加载则跳过
|
||||
if "activity_log_contents" in state:
|
||||
return None
|
||||
|
||||
contents = await self._load_recent_logs()
|
||||
|
||||
# 趁机清理旧日志(低频操作,不影响性能)
|
||||
@@ -709,7 +722,7 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
|
||||
if summary:
|
||||
await self._append_activity(summary)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record activity: %s", e)
|
||||
logger.warning(f"Failed to record activity: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -283,12 +283,7 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
) -> JobsStateUpdate | None:
|
||||
"""在 Agent 执行前异步加载任务元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "jobs_metadata" in state:
|
||||
return None
|
||||
|
||||
return JobsStateUpdate(
|
||||
jobs_metadata=await load_jobs_metadata(self.sources)
|
||||
)
|
||||
|
||||
@@ -302,7 +302,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
"""在代理执行前扫描记忆目录并加载所有 .md 文件的内容。
|
||||
|
||||
自动发现目录下所有 `.md` 文件并加载其内容到状态中。
|
||||
如果状态中尚未存在则进行加载。
|
||||
同时检测记忆文件是否为空,设置 memory_empty 标志位,
|
||||
以便在系统提示词中触发初始化引导流程。
|
||||
|
||||
@@ -314,10 +313,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
返回:
|
||||
填充了 memory_contents 和 memory_empty 的状态更新。
|
||||
"""
|
||||
# 如果已经加载则跳过
|
||||
if "memory_contents" in state:
|
||||
return None
|
||||
|
||||
# 扫描目录下所有 .md 文件
|
||||
md_files = await self._scan_memory_files()
|
||||
|
||||
|
||||
@@ -322,7 +322,7 @@ def _extract_version(skill_md: Path) -> int:
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
logger.debug(f"读取技能版本失败: {err}")
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
@@ -627,13 +627,8 @@ class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # no
|
||||
) -> SkillsStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""在 Agent 执行前异步加载技能元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
首次加载时,会先将内置技能同步到用户目录(如不存在)。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "skills_metadata" in state:
|
||||
return None
|
||||
|
||||
self._sync_bundled_skills()
|
||||
|
||||
all_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
@@ -197,18 +197,44 @@ def is_subagent_stream_metadata(metadata: Any) -> bool:
|
||||
) == SUBAGENT_STREAM_MARKER_VALUE:
|
||||
return True
|
||||
|
||||
return bool(metadata.get("lc_agent_name") in builtin_subagent_names())
|
||||
return bool(
|
||||
metadata.get("lc_agent_name")
|
||||
in builtin_subagent_names(agent_runtime_manager.current_signature())
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def builtin_subagent_names() -> frozenset[str]:
|
||||
def builtin_subagent_names(
|
||||
runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None,
|
||||
) -> frozenset[str]:
|
||||
"""返回内置子代理名称集合。"""
|
||||
return frozenset(profile.name for profile in _builtin_subagent_profiles())
|
||||
runtime_signature = runtime_signature or agent_runtime_manager.current_signature()
|
||||
return _cached_builtin_subagent_names(runtime_signature)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]:
|
||||
@lru_cache(maxsize=8)
|
||||
def _cached_builtin_subagent_names(
|
||||
runtime_signature: tuple[tuple[str, int, int], ...],
|
||||
) -> frozenset[str]:
|
||||
"""按运行时签名缓存内置子代理名称集合。"""
|
||||
return frozenset(
|
||||
profile.name
|
||||
for profile in _builtin_subagent_profiles(runtime_signature)
|
||||
)
|
||||
|
||||
|
||||
def _builtin_subagent_profiles(
|
||||
runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None,
|
||||
) -> tuple[_SubAgentProfile, ...]:
|
||||
"""从运行时配置目录加载 MoviePilot 子代理定义。"""
|
||||
runtime_signature = runtime_signature or agent_runtime_manager.current_signature()
|
||||
return _cached_builtin_subagent_profiles(runtime_signature)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cached_builtin_subagent_profiles(
|
||||
runtime_signature: tuple[tuple[str, int, int], ...],
|
||||
) -> tuple[_SubAgentProfile, ...]:
|
||||
"""按运行时签名缓存 MoviePilot 子代理定义。"""
|
||||
definitions = agent_runtime_manager.list_subagents()
|
||||
profiles = tuple(
|
||||
_profile_from_runtime_definition(definition)
|
||||
@@ -237,6 +263,10 @@ def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]:
|
||||
)
|
||||
|
||||
|
||||
builtin_subagent_names.cache_clear = _cached_builtin_subagent_names.cache_clear
|
||||
_builtin_subagent_profiles.cache_clear = _cached_builtin_subagent_profiles.cache_clear
|
||||
|
||||
|
||||
def _profile_from_runtime_definition(
|
||||
definition: SubAgentDefinition,
|
||||
) -> _SubAgentProfile:
|
||||
@@ -1044,6 +1074,7 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
|
||||
if unfinished_records:
|
||||
logger.info(f"Agent 结束,取消未完成子代理任务: tasks={len(unfinished_records)}")
|
||||
await self._cancel_records(unfinished_records)
|
||||
self._tasks.clear()
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
@@ -1083,9 +1114,8 @@ def create_subagent_middlewares(
|
||||
stream_handler: Any = None,
|
||||
) -> tuple[list[AgentMiddleware], list[BaseTool]]:
|
||||
"""创建子代理中间件列表和任务工具列表。"""
|
||||
_builtin_subagent_profiles.cache_clear()
|
||||
builtin_subagent_names.cache_clear()
|
||||
profiles = _builtin_subagent_profiles()
|
||||
runtime_signature = agent_runtime_manager.current_signature()
|
||||
profiles = _builtin_subagent_profiles(runtime_signature)
|
||||
subagent_middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=profiles,
|
||||
|
||||
@@ -592,22 +592,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果,
|
||||
不会为每次模型回合重复追加一笔 selector LLM 开销。
|
||||
"""
|
||||
if "selected_tool_names" in state:
|
||||
self._log_selection_attempt(
|
||||
_ToolSelectionAttempt(
|
||||
request=ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
),
|
||||
selected_tool_names=state.get("selected_tool_names") or [],
|
||||
status="reused",
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
if not self.selection_tools or self.model is None:
|
||||
detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型"
|
||||
self._log_selection_attempt(
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional
|
||||
@@ -243,9 +244,15 @@ class AgentRuntimeManager:
|
||||
self._cache_lock = threading.Lock()
|
||||
self._cached_signature: Optional[tuple[tuple[str, int, int], ...]] = None
|
||||
self._cached_config: Optional[AgentRuntimeConfig] = None
|
||||
self._cached_signature_checked_at = 0.0
|
||||
self._signature_check_interval = 1.0
|
||||
self._layout_ready = False
|
||||
|
||||
def ensure_layout(self) -> None:
|
||||
"""创建目录、同步默认文件,并清理废弃的旧版 runtime 文件。"""
|
||||
with self._cache_lock:
|
||||
if self._layout_ready:
|
||||
return
|
||||
self.agent_root_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.runtime_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -257,11 +264,13 @@ class AgentRuntimeManager:
|
||||
self._remove_obsolete_runtime_files()
|
||||
self._sync_bundled_defaults()
|
||||
self._migrate_root_memory_files()
|
||||
with self._cache_lock:
|
||||
self._layout_ready = True
|
||||
|
||||
def load_runtime_config(self) -> AgentRuntimeConfig:
|
||||
"""加载配置。用户目录损坏时自动回退到内置默认配置。"""
|
||||
self.ensure_layout()
|
||||
signature = self._build_signature()
|
||||
signature = self.current_signature()
|
||||
with self._cache_lock:
|
||||
if self._cached_signature == signature and self._cached_config:
|
||||
return self._cached_config
|
||||
@@ -269,7 +278,7 @@ class AgentRuntimeManager:
|
||||
try:
|
||||
config = self._load_from_root(self.runtime_dir)
|
||||
except AgentRuntimeConfigError as err:
|
||||
logger.warning("Agent 根层配置无效,回退到内置默认配置: %s", err)
|
||||
logger.warning(f"Agent 根层配置无效,回退到内置默认配置: {err}")
|
||||
config = self._load_from_root(self.bundled_defaults_dir)
|
||||
config.used_fallback = True
|
||||
config.warnings.insert(
|
||||
@@ -285,6 +294,25 @@ class AgentRuntimeManager:
|
||||
with self._cache_lock:
|
||||
self._cached_signature = None
|
||||
self._cached_config = None
|
||||
self._cached_signature_checked_at = 0.0
|
||||
self._layout_ready = False
|
||||
|
||||
def current_signature(self) -> tuple[tuple[str, int, int], ...]:
|
||||
"""返回当前运行时配置文件签名,供调用方判断缓存是否仍可复用。"""
|
||||
now = time.monotonic()
|
||||
with self._cache_lock:
|
||||
if (
|
||||
self._cached_signature is not None
|
||||
and now - self._cached_signature_checked_at
|
||||
< self._signature_check_interval
|
||||
):
|
||||
return self._cached_signature
|
||||
|
||||
signature = self._build_signature()
|
||||
with self._cache_lock:
|
||||
self._cached_signature = signature
|
||||
self._cached_signature_checked_at = now
|
||||
return signature
|
||||
|
||||
def set_active_persona(self, persona_query: str) -> AgentRuntimeConfig:
|
||||
"""切换当前激活人格,并立即刷新缓存。"""
|
||||
@@ -308,7 +336,7 @@ class AgentRuntimeManager:
|
||||
)
|
||||
current_path.write_text(document, encoding="utf-8")
|
||||
self.invalidate_cache()
|
||||
logger.info("已切换 Agent 人格: %s", persona.persona_id)
|
||||
logger.info(f"已切换 Agent 人格: {persona.persona_id}")
|
||||
return self.load_runtime_config()
|
||||
|
||||
def list_personas(self) -> list[PersonaDefinition]:
|
||||
@@ -439,7 +467,7 @@ class AgentRuntimeManager:
|
||||
continue
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(path, target)
|
||||
logger.info("已同步默认 Agent 运行时文件: %s", target)
|
||||
logger.info(f"已同步默认 Agent 运行时文件: {target}")
|
||||
|
||||
@classmethod
|
||||
def _should_update_bundled_subagent(
|
||||
@@ -478,7 +506,7 @@ class AgentRuntimeManager:
|
||||
return
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
source.rename(target)
|
||||
logger.info("已迁移旧版 Agent 根配置文件: %s -> %s", source, target)
|
||||
logger.info(f"已迁移旧版 Agent 根配置文件: {source} -> {target}")
|
||||
|
||||
def _remove_obsolete_runtime_files(self) -> None:
|
||||
"""删除不再支持的旧版 Agent 配置文件,避免被误迁移到 memory。"""
|
||||
@@ -487,14 +515,14 @@ class AgentRuntimeManager:
|
||||
if not path.exists() or not path.is_file():
|
||||
continue
|
||||
path.unlink()
|
||||
logger.info("已删除废弃的 Agent 根配置文件: %s", path)
|
||||
logger.info(f"已删除废弃的 Agent 根配置文件: {path}")
|
||||
|
||||
for relative_path in sorted(OBSOLETE_RUNTIME_FILES):
|
||||
path = self.runtime_dir / relative_path
|
||||
if not path.exists() or not path.is_file():
|
||||
continue
|
||||
path.unlink()
|
||||
logger.info("已删除废弃的 Agent 运行时文件: %s", path)
|
||||
logger.info(f"已删除废弃的 Agent 运行时文件: {path}")
|
||||
|
||||
def _migrate_root_memory_files(self) -> None:
|
||||
"""将旧版根目录 memory 文件移入 `config/agent/memory`。"""
|
||||
@@ -505,7 +533,7 @@ class AgentRuntimeManager:
|
||||
if target.exists():
|
||||
continue
|
||||
path.rename(target)
|
||||
logger.info("已迁移旧版 Agent memory 文件: %s -> %s", path, target)
|
||||
logger.info(f"已迁移旧版 Agent memory 文件: {path} -> {target}")
|
||||
|
||||
def _load_from_root(self, root: Path) -> AgentRuntimeConfig:
|
||||
current_persona_path = root / CURRENT_PERSONA_FILE
|
||||
|
||||
@@ -431,7 +431,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
:return: 普通用户允许读写的本地目录列表
|
||||
"""
|
||||
roots = [
|
||||
settings.CONFIG_PATH / "agent"
|
||||
settings.CONFIG_PATH / "agent",
|
||||
settings.LOG_PATH,
|
||||
]
|
||||
resolved_roots = []
|
||||
for root in roots:
|
||||
@@ -467,7 +468,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
allowed_text = "、".join(str(root) for root in allowed_roots)
|
||||
return (
|
||||
resolved_path,
|
||||
f"抱歉,普通用户只能{operation}配置目录、Agent记忆目录和日志目录内的文件或目录:{allowed_text}",
|
||||
f"抱歉,普通用户只能{operation}Agent配置目录和日志目录内的文件或目录:{allowed_text}",
|
||||
)
|
||||
|
||||
async def _check_local_storage_access(
|
||||
|
||||
88
app/agent/tools/impl/_command_safety.py
Normal file
88
app/agent/tools/impl/_command_safety.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Agent 命令工具的安全校验逻辑。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os.path
|
||||
import re
|
||||
import shlex
|
||||
|
||||
|
||||
COMMAND_FORBIDDEN_KEYWORDS = (
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
COMMAND_DANGEROUS_PATTERNS = (
|
||||
re.compile(r"\brm\s+[^;&|]*-[^\s;&|]*[rR][fF]?[^\s;&|]*\s+/(?:\s|$|[;&|])"),
|
||||
re.compile(r"\bdd\s+[^;&|]*(?:of=/dev/(?:sd[a-z]\d*|nvme\d+n\d+p?\d*|disk\d+)|if=/dev/zero)"),
|
||||
re.compile(r"\b(?:mkfs|fdisk|parted|diskutil)\b"),
|
||||
re.compile(r"\b(?:chmod|chown)\s+[^;&|]*-R[^;&|]*\s+/(?:\s|$|[;&|])"),
|
||||
re.compile(r"\b(?:reboot|shutdown|halt|poweroff)\b"),
|
||||
)
|
||||
|
||||
|
||||
def _command_tokens(command: str) -> list[str]:
|
||||
"""尽力解析 shell 命令 token,解析失败时退回空白分割。"""
|
||||
try:
|
||||
return shlex.split(command, posix=True)
|
||||
except ValueError:
|
||||
return re.split(r"\s+", command.strip())
|
||||
|
||||
|
||||
def _contains_recursive_root_delete(command: str) -> bool:
|
||||
"""识别递归删除根目录或一级目录的 rm 命令。"""
|
||||
tokens = _command_tokens(command)
|
||||
if not any(token == "rm" or token.endswith("/rm") for token in tokens):
|
||||
return False
|
||||
has_recursive = any(
|
||||
token.startswith("-") and ("r" in token or "R" in token)
|
||||
for token in tokens
|
||||
)
|
||||
if not has_recursive:
|
||||
return False
|
||||
|
||||
for token in tokens:
|
||||
clean_token = re.match(r"^([^;|&><]+)", token)
|
||||
if not clean_token:
|
||||
continue
|
||||
path_value = clean_token.group(1).strip("\"'")
|
||||
if not path_value.startswith("/"):
|
||||
continue
|
||||
norm_path = os.path.normpath(path_value)
|
||||
if norm_path == "/" or re.match(r"^/[^/]+$", norm_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def detect_dangerous_command(command: str) -> str:
|
||||
"""返回危险命令原因,安全时返回空字符串。"""
|
||||
normalized = str(command or "").strip()
|
||||
if not normalized:
|
||||
return "命令不能为空"
|
||||
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
|
||||
if keyword in normalized:
|
||||
return f"命令包含禁止使用的关键字 '{keyword}'"
|
||||
if _contains_recursive_root_delete(normalized):
|
||||
return "命令疑似递归删除根目录或一级目录"
|
||||
for pattern in COMMAND_DANGEROUS_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "命令匹配高危系统操作模式"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_command_safety(command: str, *, confirmed: bool = False) -> None:
|
||||
"""
|
||||
校验 shell 命令安全性。
|
||||
|
||||
:param command: 待执行命令
|
||||
:param confirmed: 是否已经通过显式参数确认高危操作
|
||||
"""
|
||||
reason = detect_dangerous_command(command)
|
||||
if not reason:
|
||||
return
|
||||
if confirmed and reason != "命令不能为空":
|
||||
return
|
||||
raise ValueError(f"{reason}。如确认需要执行,请设置 confirm_dangerous=true")
|
||||
@@ -5,8 +5,7 @@ import re
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.modules.filter.RuleParser import RuleParser
|
||||
@@ -284,23 +283,22 @@ async def collect_rule_group_usages(
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_best_version"] = True
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
filter_groups = subscribe.filter_groups or []
|
||||
for name in filter_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"type": subscribe.type,
|
||||
"username": subscribe.username,
|
||||
"best_version": bool(subscribe.best_version),
|
||||
}
|
||||
)
|
||||
subscribes = await SubscribeOper().async_list()
|
||||
for subscribe in subscribes:
|
||||
filter_groups = subscribe.filter_groups or []
|
||||
for name in filter_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"type": subscribe.type,
|
||||
"username": subscribe.username,
|
||||
"best_version": bool(subscribe.best_version),
|
||||
}
|
||||
)
|
||||
|
||||
return usage_map
|
||||
|
||||
@@ -482,22 +480,22 @@ async def rename_rule_group_references(old_name: str, new_name: str) -> dict:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe_oper.async_update_filter_groups(subscribe.id, updated)
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
@@ -520,21 +518,21 @@ async def remove_rule_group_references(group_name: str) -> dict:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe_oper.async_update_filter_groups(subscribe.id, updated)
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""系统设置工具共用的键解析与分组元数据。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import Settings
|
||||
from app.schemas.types import SystemConfigKey
|
||||
@@ -15,6 +15,7 @@ class SettingSpec:
|
||||
source: str
|
||||
group: str
|
||||
label: str
|
||||
systemconfig_key: Optional[SystemConfigKey] = None
|
||||
|
||||
|
||||
SYSTEMCONFIG_SETTING_METADATA = {
|
||||
@@ -234,6 +235,7 @@ def _build_specs() -> tuple[dict[str, SettingSpec], dict[str, SettingSpec]]:
|
||||
source="systemconfig",
|
||||
group=metadata.get("group", "misc"),
|
||||
label=metadata.get("label", item.value),
|
||||
systemconfig_key=item,
|
||||
)
|
||||
return core_specs, system_specs
|
||||
|
||||
@@ -333,3 +335,57 @@ def list_setting_specs(
|
||||
|
||||
def get_default_list_match_field(setting_key: str) -> Optional[str]:
|
||||
return LIST_ITEM_MATCH_FIELD_DEFAULTS.get(setting_key)
|
||||
|
||||
|
||||
SECRET_KEYWORDS = (
|
||||
"api_key",
|
||||
"apikey",
|
||||
"token",
|
||||
"secret",
|
||||
"password",
|
||||
"passwd",
|
||||
"cookie",
|
||||
"authorization",
|
||||
"refresh_token",
|
||||
"access_token",
|
||||
)
|
||||
|
||||
|
||||
def is_secret_setting_key(key: str) -> bool:
|
||||
"""判断设置键名是否疑似敏感字段。"""
|
||||
normalized = _normalize_token(key)
|
||||
return any(keyword in normalized for keyword in SECRET_KEYWORDS)
|
||||
|
||||
|
||||
def redact_secret_value(value: Any, *, redact_scalar: bool = False) -> Any:
|
||||
"""递归脱敏配置值中的密钥、Cookie、Token 等敏感字段。"""
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
key: "***"
|
||||
if is_secret_setting_key(str(key))
|
||||
else redact_secret_value(item, redact_scalar=redact_scalar)
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
redact_secret_value(item, redact_scalar=redact_scalar)
|
||||
for item in value
|
||||
]
|
||||
if isinstance(value, str):
|
||||
return "***" if value and redact_scalar else value
|
||||
return value
|
||||
|
||||
|
||||
def should_redact_setting(spec: SettingSpec, value: Any) -> bool:
|
||||
"""判断某项设置在默认查询响应中是否需要脱敏。"""
|
||||
if is_secret_setting_key(spec.key):
|
||||
return True
|
||||
if isinstance(value, dict):
|
||||
return any(is_secret_setting_key(str(key)) for key in value.keys())
|
||||
if isinstance(value, list):
|
||||
return any(
|
||||
should_redact_setting(spec, item)
|
||||
for item in value
|
||||
if isinstance(item, dict)
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,7 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.tools.impl._command_safety import validate_command_safety
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -34,14 +35,6 @@ TERMINAL_PTY_POLL_INTERVAL = 0.05
|
||||
TERMINAL_WAIT_DEFAULT_MS = 1000
|
||||
TERMINAL_WAIT_MAX_MS = 60 * 1000
|
||||
TERMINAL_KILL_GRACE_SECONDS = 3
|
||||
TERMINAL_FORBIDDEN_KEYWORDS = (
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -176,13 +169,9 @@ class _TerminalSessionManager:
|
||||
return merged_env
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
def _validate_command(command: str, *, confirmed: bool = False) -> None:
|
||||
"""拒绝明显危险或空白命令。"""
|
||||
if not command or not command.strip():
|
||||
raise ValueError("命令不能为空")
|
||||
for keyword in TERMINAL_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
validate_command_safety(command, confirmed=confirmed)
|
||||
|
||||
@staticmethod
|
||||
def _set_nonblocking(fd: int) -> None:
|
||||
@@ -213,9 +202,10 @@ class _TerminalSessionManager:
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict[str, Any]] = None,
|
||||
use_pty: Any = True,
|
||||
confirm_dangerous: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""启动后台命令并立即返回会话 ID。"""
|
||||
self._validate_command(command)
|
||||
self._validate_command(command, confirmed=confirm_dangerous)
|
||||
normalized_cwd = self._normalize_cwd(cwd)
|
||||
normalized_env = self._build_env(env)
|
||||
should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix"
|
||||
@@ -313,7 +303,10 @@ class _TerminalSessionManager:
|
||||
continue
|
||||
except OSError as err:
|
||||
if err.errno not in {errno.EIO, errno.EBADF}:
|
||||
logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err)
|
||||
logger.debug(
|
||||
f"PTY 输出读取异常: session_id={session.session_id}, "
|
||||
f"error={err}"
|
||||
)
|
||||
break
|
||||
|
||||
if not data:
|
||||
@@ -343,7 +336,9 @@ class _TerminalSessionManager:
|
||||
session.mark_finished(session.exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
logger.warning(
|
||||
f"等待 PTY 进程失败: session_id={session.session_id}, error={err}"
|
||||
)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
session.close_pty()
|
||||
@@ -358,7 +353,9 @@ class _TerminalSessionManager:
|
||||
session.mark_finished(exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
logger.warning(
|
||||
f"等待管道进程失败: session_id={session.session_id}, error={err}"
|
||||
)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
|
||||
|
||||
@@ -228,6 +228,11 @@ class BrowseWebpageTool(MoviePilotTool):
|
||||
return "错误: 'fill_ref' 操作需要提供 value 参数"
|
||||
if browser_action == BrowserAction.EVALUATE and not script:
|
||||
return "错误: 'evaluate' 操作需要提供 script 参数"
|
||||
if (
|
||||
browser_action == BrowserAction.EVALUATE
|
||||
and not await self.is_admin_user()
|
||||
):
|
||||
return "错误: 'evaluate' 操作仅允许管理员使用"
|
||||
if (
|
||||
browser_action in (BrowserAction.FOCUS_TAB, BrowserAction.CLOSE_TAB)
|
||||
and tab_index is None
|
||||
|
||||
@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -40,9 +39,8 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
async with AsyncSessionFactory() as db:
|
||||
await DownloadHistory.async_delete(db, history_id)
|
||||
return f"下载历史记录 ID: {history_id} 已成功删除"
|
||||
await DownloadHistoryOper().async_delete_history(history_id)
|
||||
return f"下载历史记录 ID: {history_id} 已成功删除"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载历史记录失败: {e}", exc_info=True)
|
||||
return f"删除下载历史记录时发生错误: {str(e)}"
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.log import logger
|
||||
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
"""文件编辑工具的输入参数模型。"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
@@ -27,8 +27,8 @@ class EditFileTool(MoviePilotTool):
|
||||
]
|
||||
description: str = (
|
||||
"Edit a local text file by replacing specific old text with new text. "
|
||||
"Non-admin users can only edit files inside the MoviePilot config, "
|
||||
"Agent memory/activity, and log directories."
|
||||
"Non-admin users can only edit files inside the MoviePilot Agent config "
|
||||
"and log directories."
|
||||
)
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Any, Literal, Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.impl._command_safety import validate_command_safety
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._terminal_session import (
|
||||
@@ -30,14 +31,6 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
COMMAND_FORBIDDEN_KEYWORDS = (
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
|
||||
@@ -195,6 +188,13 @@ class ExecuteCommandInput(BaseModel):
|
||||
60,
|
||||
description="For action=run, max execution time in seconds.",
|
||||
)
|
||||
confirm_dangerous: Optional[bool] = Field(
|
||||
False,
|
||||
description=(
|
||||
"Explicit confirmation for high-risk commands such as recursive root deletion, "
|
||||
"disk formatting, shutdown/reboot, or destructive permission changes."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
@@ -255,34 +255,9 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
return command
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
def _validate_command(command: str, *, confirmed: bool = False) -> None:
|
||||
"""复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。"""
|
||||
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
|
||||
# 检查是否使用了 rm -r/R 删除根目录或一级目录,防止误杀多级目录
|
||||
import re
|
||||
import os.path
|
||||
tokens = re.split(r'\s+', command.strip())
|
||||
if any(t == "rm" or t.endswith("/rm") for t in tokens):
|
||||
has_r = False
|
||||
for token in tokens:
|
||||
if token.startswith("-") and ("r" in token or "R" in token):
|
||||
has_r = True
|
||||
break
|
||||
|
||||
if has_r:
|
||||
for token in tokens:
|
||||
# 提取可能包含目标路径的部分(去除重定向、管道、分号等末尾干扰)
|
||||
m = re.match(r'^([^;\|&><]+)', token)
|
||||
if m:
|
||||
clean_token = m.group(1).strip('"\'')
|
||||
# 仅对绝对路径进行一级目录限制
|
||||
if clean_token.startswith('/'):
|
||||
norm_path = os.path.normpath(clean_token)
|
||||
if re.match(r'^/[^/]*$', norm_path) or re.match(r'^/[^/]*/$', norm_path):
|
||||
raise ValueError(f"不允许使用 rm 命令删除根目录或一级目录: {clean_token}")
|
||||
validate_command_safety(command, confirmed=confirmed)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
|
||||
@@ -367,7 +342,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("命令进程强制清理超时: pid=%s", process.pid)
|
||||
logger.warning(f"命令进程强制清理超时: pid={process.pid}")
|
||||
|
||||
@staticmethod
|
||||
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
|
||||
@@ -382,7 +357,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
if isinstance(result, Exception) and not isinstance(
|
||||
result, asyncio.CancelledError
|
||||
):
|
||||
logger.debug("命令输出读取任务异常: %s", result)
|
||||
logger.debug(f"命令输出读取任务异常: {result}")
|
||||
|
||||
@staticmethod
|
||||
def _format_run_result(
|
||||
@@ -425,9 +400,10 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
command: str,
|
||||
timeout: Optional[int],
|
||||
cwd: Optional[str] = None,
|
||||
confirm_dangerous: bool = False,
|
||||
) -> str:
|
||||
"""按旧模式一次性执行命令,等待完成或超时后返回文本结果。"""
|
||||
self._validate_command(command)
|
||||
self._validate_command(command, confirmed=confirm_dangerous)
|
||||
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
|
||||
|
||||
async with _command_semaphore:
|
||||
@@ -482,27 +458,29 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
|
||||
timeout: Optional[int] = 60,
|
||||
confirm_dangerous: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
|
||||
normalized_action = (action or "start").strip().lower()
|
||||
logger.info(
|
||||
"执行工具: %s, action=%s, command=%s, session_id=%s",
|
||||
self.name,
|
||||
normalized_action,
|
||||
command,
|
||||
session_id,
|
||||
f"执行工具: {self.name}, action={normalized_action}, "
|
||||
f"command={command}, session_id={session_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
if normalized_action == "start":
|
||||
start_command = self._require_command(command)
|
||||
self._validate_command(start_command)
|
||||
self._validate_command(
|
||||
start_command,
|
||||
confirmed=bool(confirm_dangerous),
|
||||
)
|
||||
payload = await terminal_session_manager.start(
|
||||
command=start_command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
use_pty=use_pty,
|
||||
confirm_dangerous=bool(confirm_dangerous),
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
@@ -542,9 +520,10 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
command=self._require_command(command),
|
||||
timeout=timeout,
|
||||
cwd=cwd,
|
||||
confirm_dangerous=bool(confirm_dangerous),
|
||||
)
|
||||
|
||||
raise ValueError(f"不支持的 action: {action}")
|
||||
except Exception as err:
|
||||
logger.error("执行命令 action 失败: %s", err, exc_info=True)
|
||||
logger.error(f"执行命令 action 失败: {err}", exc_info=True)
|
||||
return self._dump({"error": str(err), "status": "error", "action": normalized_action})
|
||||
|
||||
@@ -7,9 +7,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10
|
||||
@@ -66,118 +64,115 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db, domain=site.domain, workdate=workdate
|
||||
site_oper = SiteOper()
|
||||
site = await site_oper.async_get(site_id)
|
||||
if not site:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
user_data_list = await site_oper.async_get_userdata_by_domain(
|
||||
domain=site.domain, workdate=workdate
|
||||
)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": [],
|
||||
if not user_data_list:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": [],
|
||||
}
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
|
||||
download_gb = (
|
||||
user_data.download / (1024**3) if user_data.download else 0
|
||||
)
|
||||
seeding_size_gb = (
|
||||
user_data.seeding_size / (1024**3)
|
||||
if user_data.seeding_size
|
||||
else 0
|
||||
)
|
||||
leeching_size_gb = (
|
||||
user_data.leeching_size / (1024**3)
|
||||
if user_data.leeching_size
|
||||
else 0
|
||||
)
|
||||
|
||||
seeding_preview, seeding_count, seeding_truncated = _preview_list(
|
||||
user_data.seeding_info
|
||||
)
|
||||
unread_preview, unread_count, unread_truncated = _preview_list(
|
||||
user_data.message_unread_contents
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
"username": user_data.username,
|
||||
"userid": user_data.userid,
|
||||
"user_level": user_data.user_level,
|
||||
"join_at": user_data.join_at,
|
||||
"bonus": user_data.bonus,
|
||||
"upload": user_data.upload,
|
||||
"upload_gb": round(upload_gb, 2),
|
||||
"download": user_data.download,
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching)
|
||||
if user_data.leeching
|
||||
else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info_count": seeding_count,
|
||||
"seeding_info": seeding_preview,
|
||||
"seeding_info_truncated": seeding_truncated,
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents_count": unread_count,
|
||||
"message_unread_contents": unread_preview,
|
||||
"message_unread_contents_truncated": unread_truncated,
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time,
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
|
||||
download_gb = (
|
||||
user_data.download / (1024**3) if user_data.download else 0
|
||||
)
|
||||
seeding_size_gb = (
|
||||
user_data.seeding_size / (1024**3)
|
||||
if user_data.seeding_size
|
||||
else 0
|
||||
)
|
||||
leeching_size_gb = (
|
||||
user_data.leeching_size / (1024**3)
|
||||
if user_data.leeching_size
|
||||
else 0
|
||||
)
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (
|
||||
x.get("updated_day", ""),
|
||||
x.get("updated_time", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
result["message"] = (
|
||||
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
)
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
seeding_preview, seeding_count, seeding_truncated = _preview_list(
|
||||
user_data.seeding_info
|
||||
)
|
||||
unread_preview, unread_count, unread_truncated = _preview_list(
|
||||
user_data.message_unread_contents
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
"username": user_data.username,
|
||||
"userid": user_data.userid,
|
||||
"user_level": user_data.user_level,
|
||||
"join_at": user_data.join_at,
|
||||
"bonus": user_data.bonus,
|
||||
"upload": user_data.upload,
|
||||
"upload_gb": round(upload_gb, 2),
|
||||
"download": user_data.download,
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching)
|
||||
if user_data.leeching
|
||||
else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info_count": seeding_count,
|
||||
"seeding_info": seeding_preview,
|
||||
"seeding_info_truncated": seeding_truncated,
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents_count": unread_count,
|
||||
"message_unread_contents": unread_preview,
|
||||
"message_unread_contents_truncated": unread_truncated,
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time,
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (
|
||||
x.get("updated_day", ""),
|
||||
x.get("updated_time", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
result["message"] = (
|
||||
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
)
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
|
||||
@@ -7,8 +7,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.db.subscribehistory_oper import SubscribeHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
@@ -74,88 +73,87 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
if media_type not in ["all", "movie", "tv"]:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
if name:
|
||||
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
|
||||
fetch_count = 500
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=fetch_count
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=fetch_count
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
all_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=fetch_count
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称过滤
|
||||
name_lower = name.lower()
|
||||
filtered_history = [
|
||||
record
|
||||
for record in all_history
|
||||
if record.name and name_lower in record.name.lower()
|
||||
]
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 名称过滤时直接返回所有匹配结果,不分页
|
||||
simplified_records = self._simplify_records(filtered_history)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
subscribe_history_oper = SubscribeHistoryOper()
|
||||
if name:
|
||||
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
|
||||
fetch_count = 500
|
||||
if media_type == "all":
|
||||
movie_history = await subscribe_history_oper.async_list_by_type(
|
||||
mtype="movie", page=1, count=fetch_count
|
||||
)
|
||||
return result_json
|
||||
tv_history = await subscribe_history_oper.async_list_by_type(
|
||||
mtype="tv", page=1, count=fetch_count
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
# 无名称过滤时,直接利用数据库分页
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
filtered_history = all_history
|
||||
else:
|
||||
filtered_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
all_history = list(
|
||||
await subscribe_history_oper.async_list_by_type(
|
||||
mtype=media_type, page=1, count=fetch_count
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称过滤
|
||||
name_lower = name.lower()
|
||||
filtered_history = [
|
||||
record
|
||||
for record in all_history
|
||||
if record.name and name_lower in record.name.lower()
|
||||
]
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 分页切片
|
||||
total_count = len(filtered_history)
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_records = filtered_history[start:end]
|
||||
|
||||
if not page_records:
|
||||
return f"第 {page} 页没有数据。"
|
||||
|
||||
simplified_records = self._simplify_records(page_records)
|
||||
# 名称过滤时直接返回所有匹配结果,不分页
|
||||
simplified_records = self._simplify_records(filtered_history)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
has_more = total_count > end
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
return result_json
|
||||
else:
|
||||
# 无名称过滤时,直接利用数据库分页
|
||||
if media_type == "all":
|
||||
movie_history = await subscribe_history_oper.async_list_by_type(
|
||||
mtype="movie", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
tv_history = await subscribe_history_oper.async_list_by_type(
|
||||
mtype="tv", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
filtered_history = all_history
|
||||
else:
|
||||
filtered_history = list(
|
||||
await subscribe_history_oper.async_list_by_type(
|
||||
mtype=media_type, page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
)
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 分页切片
|
||||
total_count = len(filtered_history)
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_records = filtered_history[start:end]
|
||||
|
||||
if not page_records:
|
||||
return f"第 {page} 页没有数据。"
|
||||
|
||||
simplified_records = self._simplify_records(page_records)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
has_more = total_count > end
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
@@ -9,8 +9,11 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._system_setting_utils import (
|
||||
SettingSpec,
|
||||
is_secret_setting_key,
|
||||
list_setting_specs,
|
||||
redact_secret_value,
|
||||
resolve_setting_spec,
|
||||
should_redact_setting,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -53,6 +56,13 @@ class QuerySystemSettingsInput(BaseModel):
|
||||
"when multiple settings are matched it returns summaries only unless this is explicitly set to true."
|
||||
),
|
||||
)
|
||||
show_secrets: Optional[bool] = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to return raw secret values such as API keys, tokens, cookies, and passwords. "
|
||||
"Defaults to false; secret-like fields are redacted in returned values and previews."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class QuerySystemSettingsTool(MoviePilotTool):
|
||||
@@ -85,15 +95,18 @@ class QuerySystemSettingsTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _load_setting_value(spec: SettingSpec):
|
||||
"""读取指定设置项的当前值。"""
|
||||
if spec.source == "settings":
|
||||
return getattr(settings, spec.key)
|
||||
return SystemConfigOper().get(spec.key)
|
||||
return SystemConfigOper().get(spec.systemconfig_key)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_value(value) -> dict:
|
||||
def _summarize_value(value, *, redacted: bool = False) -> dict:
|
||||
"""生成设置值摘要,避免列表和字典默认输出过长。"""
|
||||
summary = {
|
||||
"has_value": value is not None,
|
||||
"value_type": type(value).__name__,
|
||||
"redacted": redacted,
|
||||
}
|
||||
if isinstance(value, list):
|
||||
summary["item_count"] = len(value)
|
||||
@@ -122,14 +135,12 @@ class QuerySystemSettingsTool(MoviePilotTool):
|
||||
group: Optional[str] = "all",
|
||||
keyword: Optional[str] = None,
|
||||
include_values: Optional[bool] = None,
|
||||
show_secrets: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
"执行工具: %s, setting_key=%s, group=%s, keyword=%s",
|
||||
self.name,
|
||||
setting_key,
|
||||
group,
|
||||
keyword,
|
||||
f"执行工具: {self.name}, setting_key={setting_key}, "
|
||||
f"group={group}, keyword={keyword}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -158,18 +169,30 @@ class QuerySystemSettingsTool(MoviePilotTool):
|
||||
should_include_values = (
|
||||
include_values if include_values is not None else len(specs) == 1
|
||||
)
|
||||
allow_secret_values = bool(show_secrets) and await self.is_admin_user()
|
||||
settings_payload = []
|
||||
for spec in specs:
|
||||
value = self._load_setting_value(spec)
|
||||
should_redact = (
|
||||
should_redact_setting(spec, value) and not allow_secret_values
|
||||
)
|
||||
response_value = (
|
||||
redact_secret_value(
|
||||
value,
|
||||
redact_scalar=is_secret_setting_key(spec.key),
|
||||
)
|
||||
if should_redact
|
||||
else value
|
||||
)
|
||||
item = {
|
||||
"setting_key": spec.key,
|
||||
"source": spec.source,
|
||||
"group": spec.group,
|
||||
"label": spec.label,
|
||||
}
|
||||
item.update(self._summarize_value(value))
|
||||
item.update(self._summarize_value(response_value, redacted=should_redact))
|
||||
if should_include_values:
|
||||
item["value"] = value
|
||||
item["value"] = response_value
|
||||
settings_payload.append(item)
|
||||
|
||||
return json.dumps(
|
||||
@@ -177,6 +200,7 @@ class QuerySystemSettingsTool(MoviePilotTool):
|
||||
"success": True,
|
||||
"matched_count": len(settings_payload),
|
||||
"include_values": should_include_values,
|
||||
"show_secrets": allow_secret_values,
|
||||
"settings": settings_payload,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
|
||||
@@ -7,8 +7,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
from app.utils.jieba import cut as jieba_cut
|
||||
@@ -70,70 +69,69 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
# 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。
|
||||
count = 30
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用统一分词封装处理标题,便于替换底层实现。
|
||||
words = jieba_cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await TransferHistory.async_list_by_title(
|
||||
db, title=title_search, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count_by_title(
|
||||
db, title=title_search, status=status_bool
|
||||
)
|
||||
else:
|
||||
# 查询所有记录
|
||||
result = await TransferHistory.async_list_by_page(
|
||||
db, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count(db, status=status_bool)
|
||||
transferhis = TransferHistoryOper()
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用统一分词封装处理标题,便于替换底层实现。
|
||||
words = jieba_cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await transferhis.async_list_by_title(
|
||||
title=title_search, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await transferhis.async_count_by_title(
|
||||
title=title_search, status=status_bool
|
||||
)
|
||||
else:
|
||||
# 查询所有记录
|
||||
result = await transferhis.async_list_by_page(
|
||||
page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await transferhis.async_count(status=status_bool)
|
||||
|
||||
if not result:
|
||||
return "未找到相关整理历史记录"
|
||||
if not result:
|
||||
return "未找到相关整理历史记录"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in result:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
"src": record.src,
|
||||
"dest": record.dest,
|
||||
"mode": record.mode,
|
||||
"status": "成功" if record.status else "失败",
|
||||
"date": record.date,
|
||||
"downloader": record.downloader,
|
||||
"download_hash": record.download_hash
|
||||
}
|
||||
# 如果失败,添加错误信息
|
||||
if not record.status and record.errmsg:
|
||||
simplified["errmsg"] = record.errmsg
|
||||
# 添加媒体ID信息(如果有)
|
||||
if record.tmdbid:
|
||||
simplified["tmdbid"] = record.tmdbid
|
||||
if record.imdbid:
|
||||
simplified["imdbid"] = record.imdbid
|
||||
if record.doubanid:
|
||||
simplified["doubanid"] = record.doubanid
|
||||
simplified_records.append(simplified)
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in result:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
"src": record.src,
|
||||
"dest": record.dest,
|
||||
"mode": record.mode,
|
||||
"status": "成功" if record.status else "失败",
|
||||
"date": record.date,
|
||||
"downloader": record.downloader,
|
||||
"download_hash": record.download_hash
|
||||
}
|
||||
# 如果失败,添加错误信息
|
||||
if not record.status and record.errmsg:
|
||||
simplified["errmsg"] = record.errmsg
|
||||
# 添加媒体ID信息(如果有)
|
||||
if record.tmdbid:
|
||||
simplified["tmdbid"] = record.tmdbid
|
||||
if record.imdbid:
|
||||
simplified["imdbid"] = record.imdbid
|
||||
if record.doubanid:
|
||||
simplified["doubanid"] = record.doubanid
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (total + count - 1) // count if total > 0 else 1
|
||||
# 计算总页数
|
||||
total_pages = (total + count - 1) // count if total > 0 else 1
|
||||
|
||||
# 构建分页信息
|
||||
pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
|
||||
# 构建分页信息
|
||||
pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询整理历史记录失败: {e}", exc_info=True)
|
||||
return f"查询整理历史记录时发生错误: {str(e)}"
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
@@ -56,75 +55,73 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: state={state}, name={name}, trigger_type={trigger_type}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflows = await workflow_oper.async_list()
|
||||
|
||||
# 过滤工作流
|
||||
filtered_workflows = []
|
||||
for wf in workflows:
|
||||
# 按状态过滤
|
||||
if state != "all" and wf.state != state:
|
||||
workflow_oper = WorkflowOper()
|
||||
workflows = await workflow_oper.async_list()
|
||||
|
||||
# 过滤工作流
|
||||
filtered_workflows = []
|
||||
for wf in workflows:
|
||||
# 按状态过滤
|
||||
if state != "all" and wf.state != state:
|
||||
continue
|
||||
|
||||
# 按触发类型过滤
|
||||
if trigger_type != "all":
|
||||
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
|
||||
continue
|
||||
|
||||
# 按触发类型过滤
|
||||
if trigger_type != "all":
|
||||
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
|
||||
continue
|
||||
elif trigger_type == "event" and wf.trigger_type != "event":
|
||||
continue
|
||||
elif trigger_type == "manual" and wf.trigger_type != "manual":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and wf.name and name.lower() not in wf.name.lower():
|
||||
elif trigger_type == "event" and wf.trigger_type != "event":
|
||||
continue
|
||||
|
||||
filtered_workflows.append(wf)
|
||||
|
||||
if not filtered_workflows:
|
||||
return "未找到相关工作流"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_workflows = []
|
||||
for wf in filtered_workflows:
|
||||
# 状态说明
|
||||
state_map = {
|
||||
"W": "等待",
|
||||
"R": "运行中",
|
||||
"P": "暂停",
|
||||
"S": "成功",
|
||||
"F": "失败"
|
||||
}
|
||||
state_desc = state_map.get(wf.state, wf.state)
|
||||
|
||||
# 触发类型说明
|
||||
trigger_type_map = {
|
||||
"timer": "定时触发",
|
||||
"event": "事件触发",
|
||||
"manual": "手动触发"
|
||||
}
|
||||
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
|
||||
|
||||
simplified = {
|
||||
"id": wf.id,
|
||||
"name": wf.name,
|
||||
"description": wf.description,
|
||||
"trigger_type": trigger_type_desc,
|
||||
"state": state_desc,
|
||||
"run_count": wf.run_count,
|
||||
"timer": wf.timer,
|
||||
"event_type": wf.event_type,
|
||||
"add_time": wf.add_time,
|
||||
"last_time": wf.last_time,
|
||||
"current_action": wf.current_action
|
||||
}
|
||||
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
|
||||
simplified_workflows.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
elif trigger_type == "manual" and wf.trigger_type != "manual":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and wf.name and name.lower() not in wf.name.lower():
|
||||
continue
|
||||
|
||||
filtered_workflows.append(wf)
|
||||
|
||||
if not filtered_workflows:
|
||||
return "未找到相关工作流"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_workflows = []
|
||||
for wf in filtered_workflows:
|
||||
# 状态说明
|
||||
state_map = {
|
||||
"W": "等待",
|
||||
"R": "运行中",
|
||||
"P": "暂停",
|
||||
"S": "成功",
|
||||
"F": "失败"
|
||||
}
|
||||
state_desc = state_map.get(wf.state, wf.state)
|
||||
|
||||
# 触发类型说明
|
||||
trigger_type_map = {
|
||||
"timer": "定时触发",
|
||||
"event": "事件触发",
|
||||
"manual": "手动触发"
|
||||
}
|
||||
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
|
||||
|
||||
simplified = {
|
||||
"id": wf.id,
|
||||
"name": wf.name,
|
||||
"description": wf.description,
|
||||
"trigger_type": trigger_type_desc,
|
||||
"state": state_desc,
|
||||
"run_count": wf.run_count,
|
||||
"timer": wf.timer,
|
||||
"event_type": wf.event_type,
|
||||
"add_time": wf.add_time,
|
||||
"last_time": wf.last_time,
|
||||
"current_action": wf.current_action
|
||||
}
|
||||
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
|
||||
simplified_workflows.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
@@ -15,7 +15,7 @@ MAX_READ_SIZE = 50 * 1024
|
||||
|
||||
|
||||
class ReadFileInput(BaseModel):
|
||||
"""Input parameters for read file tool"""
|
||||
"""文件读取工具的输入参数模型。"""
|
||||
file_path: str = Field(..., description="The absolute path of the file to read")
|
||||
start_line: Optional[int] = Field(None, description="The starting line number (1-based, inclusive). If not provided, reading starts from the beginning of the file.")
|
||||
end_line: Optional[int] = Field(None, description="The ending line number (1-based, inclusive). If not provided, reading goes until the end of the file.")
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
@@ -65,26 +64,23 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
workflow = await WorkflowOper().async_get(workflow_id)
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
# 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。
|
||||
state, errmsg = await self.run_blocking(
|
||||
"workflow",
|
||||
self._run_workflow_sync,
|
||||
workflow.id,
|
||||
from_begin,
|
||||
)
|
||||
# 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。
|
||||
state, errmsg = await self.run_blocking(
|
||||
"workflow",
|
||||
self._run_workflow_sync,
|
||||
workflow.id,
|
||||
from_begin,
|
||||
)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
@@ -8,8 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.string import StringUtils
|
||||
@@ -127,108 +126,106 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
if ua is not None:
|
||||
site_dict["ua"] = ua
|
||||
if apikey is not None:
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
if filter is not None:
|
||||
site_dict["filter"] = filter
|
||||
if note is not None:
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
if limit_count is not None:
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
site_oper = SiteOper()
|
||||
site = await site_oper.async_get(site_id)
|
||||
if not site:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
if ua is not None:
|
||||
site_dict["ua"] = ua
|
||||
if apikey is not None:
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
if filter is not None:
|
||||
site_dict["filter"] = filter
|
||||
if note is not None:
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
if limit_count is not None:
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site_oper.async_update(site_id, site_dict)
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await site_oper.async_get(site_id)
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
}
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
"name": updated_site.name,
|
||||
"domain": updated_site.domain,
|
||||
"url": updated_site.url,
|
||||
"pri": updated_site.pri,
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
"name": updated_site.name,
|
||||
"domain": updated_site.domain,
|
||||
"url": updated_site.url,
|
||||
"pri": updated_site.pri,
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
|
||||
@@ -8,8 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
@@ -157,149 +156,147 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
if year is not None:
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
if resolution is not None:
|
||||
subscribe_dict["resolution"] = resolution
|
||||
if effect is not None:
|
||||
subscribe_dict["effect"] = effect
|
||||
if include is not None:
|
||||
subscribe_dict["include"] = include
|
||||
if exclude is not None:
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
if downloader is not None:
|
||||
subscribe_dict["downloader"] = downloader
|
||||
if save_path is not None:
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
if best_version_full is not None:
|
||||
subscribe_dict["best_version_full"] = best_version_full
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
if media_category is not None:
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
if year is not None:
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
if resolution is not None:
|
||||
subscribe_dict["resolution"] = resolution
|
||||
if effect is not None:
|
||||
subscribe_dict["effect"] = effect
|
||||
if include is not None:
|
||||
subscribe_dict["include"] = include
|
||||
if exclude is not None:
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
if downloader is not None:
|
||||
subscribe_dict["downloader"] = downloader
|
||||
if save_path is not None:
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
if best_version_full is not None:
|
||||
subscribe_dict["best_version_full"] = best_version_full
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
if media_category is not None:
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe_oper.async_update(subscribe_id, subscribe_dict)
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
}
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
"name": updated_subscribe.name,
|
||||
"year": updated_subscribe.year,
|
||||
"type": updated_subscribe.type,
|
||||
"season": updated_subscribe.season,
|
||||
"state": updated_subscribe.state,
|
||||
"total_episode": updated_subscribe.total_episode,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
"name": updated_subscribe.name,
|
||||
"year": updated_subscribe.year,
|
||||
"type": updated_subscribe.type,
|
||||
"season": updated_subscribe.season,
|
||||
"state": updated_subscribe.state,
|
||||
"total_episode": updated_subscribe.total_episode,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
|
||||
@@ -11,7 +11,10 @@ from app.agent.tools.tags import ToolTag
|
||||
from app.agent.tools.impl._system_setting_utils import (
|
||||
SettingSpec,
|
||||
get_default_list_match_field,
|
||||
is_secret_setting_key,
|
||||
redact_secret_value,
|
||||
resolve_setting_spec,
|
||||
should_redact_setting,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
@@ -102,12 +105,14 @@ class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _load_setting_value(spec: SettingSpec):
|
||||
"""读取指定设置项的当前值。"""
|
||||
if spec.source == "settings":
|
||||
return getattr(settings, spec.key)
|
||||
return SystemConfigOper().get(spec.key)
|
||||
return SystemConfigOper().get(spec.systemconfig_key)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_systemconfig_value(value: Any):
|
||||
"""规范化写入 SystemConfig 的空列表值。"""
|
||||
if isinstance(value, list):
|
||||
filtered = [item for item in value if item is not None]
|
||||
return filtered or None
|
||||
@@ -221,10 +226,7 @@ class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
"执行工具: %s, setting_key=%s, operation=%s",
|
||||
self.name,
|
||||
setting_key,
|
||||
operation,
|
||||
f"执行工具: {self.name}, setting_key={setting_key}, operation={operation}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -266,7 +268,10 @@ class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
else:
|
||||
normalized_value = self._normalize_systemconfig_value(next_value)
|
||||
event_value = normalized_value
|
||||
success = await SystemConfigOper().async_set(spec.key, normalized_value)
|
||||
success = await SystemConfigOper().async_set(
|
||||
spec.systemconfig_key,
|
||||
normalized_value,
|
||||
)
|
||||
changed = success is True
|
||||
|
||||
if changed:
|
||||
@@ -280,6 +285,26 @@ class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
)
|
||||
|
||||
saved_value = self._load_setting_value(spec)
|
||||
redact_values = (
|
||||
should_redact_setting(spec, saved_value)
|
||||
or should_redact_setting(spec, current_value)
|
||||
)
|
||||
response_previous_value = (
|
||||
redact_secret_value(
|
||||
current_value,
|
||||
redact_scalar=is_secret_setting_key(spec.key),
|
||||
)
|
||||
if redact_values
|
||||
else current_value
|
||||
)
|
||||
response_saved_value = (
|
||||
redact_secret_value(
|
||||
saved_value,
|
||||
redact_scalar=is_secret_setting_key(spec.key),
|
||||
)
|
||||
if redact_values
|
||||
else saved_value
|
||||
)
|
||||
if not changed and not message:
|
||||
message = "配置值未发生变化"
|
||||
|
||||
@@ -295,8 +320,9 @@ class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
"group": spec.group,
|
||||
"label": spec.label,
|
||||
},
|
||||
"previous_value": current_value,
|
||||
"saved_value": saved_value,
|
||||
"values_redacted": redact_values,
|
||||
"previous_value": response_previous_value,
|
||||
"saved_value": response_saved_value,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.log import logger
|
||||
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
"""文件写入工具的输入参数模型。"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
@@ -26,7 +26,7 @@ class WriteFileTool(MoviePilotTool):
|
||||
]
|
||||
description: str = (
|
||||
"Write full content to a local text file. Non-admin users can only write "
|
||||
"inside the MoviePilot config, Agent memory/activity, and log directories."
|
||||
"inside the MoviePilot Agent config and log directories."
|
||||
)
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
|
||||
|
||||
@@ -114,6 +114,12 @@ class DownloadHistoryOper(DbOper):
|
||||
"""
|
||||
return DownloadHistory.list_by_page(self._db, page, count)
|
||||
|
||||
async def async_delete_history(self, historyid: int):
|
||||
"""
|
||||
异步删除下载记录。
|
||||
"""
|
||||
await DownloadHistory.async_delete(self._db, historyid)
|
||||
|
||||
def truncate(self):
|
||||
"""
|
||||
清空下载记录
|
||||
|
||||
@@ -79,6 +79,15 @@ class SiteOper(DbOper):
|
||||
site.update(self._db, payload)
|
||||
return site
|
||||
|
||||
async def async_update(self, sid: int, payload: dict) -> Site:
|
||||
"""
|
||||
异步更新站点。
|
||||
"""
|
||||
site = await self.async_get(sid)
|
||||
if site:
|
||||
await site.async_update(self._db, payload)
|
||||
return site
|
||||
|
||||
def get_by_domain(self, domain: str) -> Site:
|
||||
"""
|
||||
按域名获取站点
|
||||
@@ -170,6 +179,16 @@ class SiteOper(DbOper):
|
||||
"""
|
||||
return SiteUserData.get_by_domain(self._db, domain=domain, workdate=workdate)
|
||||
|
||||
async def async_get_userdata_by_domain(
|
||||
self, domain: str, workdate: Optional[str] = None
|
||||
) -> List[SiteUserData]:
|
||||
"""
|
||||
异步获取站点用户数据。
|
||||
"""
|
||||
return await SiteUserData.async_get_by_domain(
|
||||
self._db, domain=domain, workdate=workdate
|
||||
)
|
||||
|
||||
def get_userdata_by_date(self, date: str) -> List[SiteUserData]:
|
||||
"""
|
||||
获取站点用户数据
|
||||
|
||||
@@ -169,6 +169,22 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
await Subscribe.async_delete(self._db, rid=sid)
|
||||
|
||||
async def async_update(self, sid: int, payload: dict) -> Subscribe:
|
||||
"""
|
||||
异步更新订阅。
|
||||
"""
|
||||
subscribe = await self.async_get(sid)
|
||||
if subscribe:
|
||||
payload = _normalize_integer_flags(payload)
|
||||
await subscribe.async_update(self._db, payload)
|
||||
return subscribe
|
||||
|
||||
async def async_update_filter_groups(self, sid: int, filter_groups: list) -> Subscribe:
|
||||
"""
|
||||
异步更新订阅使用的过滤规则组。
|
||||
"""
|
||||
return await self.async_update(sid, {"filter_groups": filter_groups})
|
||||
|
||||
def update(self, sid: int, payload: dict) -> Subscribe:
|
||||
"""
|
||||
更新订阅
|
||||
|
||||
26
app/db/subscribehistory_oper.py
Normal file
26
app/db/subscribehistory_oper.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.db import DbOper
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
|
||||
|
||||
class SubscribeHistoryOper(DbOper):
|
||||
"""
|
||||
订阅历史管理。
|
||||
"""
|
||||
|
||||
async def async_list_by_type(
|
||||
self,
|
||||
mtype: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
) -> List[SubscribeHistory]:
|
||||
"""
|
||||
异步按媒体类型分页查询订阅历史。
|
||||
"""
|
||||
return await SubscribeHistory.async_list_by_type(
|
||||
self._db,
|
||||
mtype=mtype,
|
||||
page=page,
|
||||
count=count,
|
||||
)
|
||||
@@ -26,6 +26,51 @@ class TransferHistoryOper(DbOper):
|
||||
"""
|
||||
return await TransferHistory.async_get(self._db, historyid)
|
||||
|
||||
async def async_list_by_title(
|
||||
self,
|
||||
title: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
status: Optional[bool] = None,
|
||||
) -> List[TransferHistory]:
|
||||
"""
|
||||
异步按标题分页查询转移记录。
|
||||
"""
|
||||
return await TransferHistory.async_list_by_title(
|
||||
self._db, title=title, page=page, count=count, status=status
|
||||
)
|
||||
|
||||
async def async_list_by_page(
|
||||
self,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
status: Optional[bool] = None,
|
||||
) -> List[TransferHistory]:
|
||||
"""
|
||||
异步分页查询转移记录。
|
||||
"""
|
||||
return await TransferHistory.async_list_by_page(
|
||||
self._db, page=page, count=count, status=status
|
||||
)
|
||||
|
||||
async def async_count(self, status: Optional[bool] = None) -> int:
|
||||
"""
|
||||
异步统计转移记录数量。
|
||||
"""
|
||||
return await TransferHistory.async_count(self._db, status=status)
|
||||
|
||||
async def async_count_by_title(
|
||||
self,
|
||||
title: str,
|
||||
status: Optional[bool] = None,
|
||||
) -> int:
|
||||
"""
|
||||
异步按标题统计转移记录数量。
|
||||
"""
|
||||
return await TransferHistory.async_count_by_title(
|
||||
self._db, title=title, status=status
|
||||
)
|
||||
|
||||
def get_by_title(self, title: str) -> List[TransferHistory]:
|
||||
"""
|
||||
按标题查询转移记录
|
||||
|
||||
Reference in New Issue
Block a user