feat: add batch AI re-organize for transfer history and search result recommendation

- Implement batch AI re-organize endpoint for transfer history with progress tracking
- Add batch_manual_transfer_redo system task template and prompt generation
- Refactor agent_manager to support generic background prompt execution
- Add AIRecommendChain for search result recommendation using agent background prompt
- Update search endpoints to use new AIRecommendChain and remove legacy code
- Enhance test cases for batch manual transfer redo
- Minor code cleanup and style fixes
This commit is contained in:
jxxghp
2026-04-29 22:16:04 +08:00
parent b6f0ef99ab
commit 460d716512
16 changed files with 821 additions and 631 deletions

View File

@@ -32,7 +32,6 @@ from app.agent.runtime import agent_runtime_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.db.transferhistory_oper import TransferHistoryOper
from app.helper.llm import LLMHelper
from app.log import logger
from app.schemas import Notification, NotificationType
@@ -731,21 +730,12 @@ class AgentManager:
同一会话的消息按顺序排队处理,不同会话之间互不影响。
"""
# 批量重试整理的等待时间同一批次内的失败记录会合并为一次agent调用
RETRY_TRANSFER_DEBOUNCE_SECONDS = 300
def __init__(self):
self.active_agents: Dict[str, MoviePilotAgent] = {}
# 每个会话的消息队列
self._session_queues: Dict[str, asyncio.Queue] = {}
# 每个会话的worker任务
self._session_workers: Dict[str, asyncio.Task] = {}
# 重试整理的 debounce 缓冲区: group_key -> List[history_id]
self._retry_transfer_buffer: Dict[str, List[int]] = {}
# 重试整理的 debounce 定时器: group_key -> asyncio.TimerHandle
self._retry_transfer_timers: Dict[str, asyncio.TimerHandle] = {}
# 重试整理缓冲区锁
self._retry_transfer_lock = asyncio.Lock()
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
@@ -790,11 +780,6 @@ class AgentManager:
关闭管理器
"""
await memory_manager.close()
# 取消所有重试整理的延迟定时器
for timer in self._retry_transfer_timers.values():
timer.cancel()
self._retry_transfer_timers.clear()
self._retry_transfer_buffer.clear()
# 取消所有会话worker
for task in self._session_workers.values():
task.cancel()
@@ -995,67 +980,40 @@ class AgentManager:
memory_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")
@staticmethod
async def run_background_prompt(
    message: str,
    session_prefix: str = "__agent_background",
    output_callback: Optional[Callable[[str], None]] = None,
    suppress_user_reply: bool = False,
) -> None:
    """
    Run a one-off prompt in an isolated, throwaway background agent session.

    :param message: the prompt text to execute
    :param session_prefix: prefix used to build the unique background session id
    :param output_callback: optional sink that receives streamed agent output
    :param suppress_user_reply: when True, the agent does not send a user-facing reply
    """
    background_session = f"{session_prefix}_{uuid.uuid4().hex[:8]}__"
    system_user = SYSTEM_INTERNAL_USER_ID
    worker = MoviePilotAgent(
        session_id=background_session,
        user_id=system_user,
        channel=None,
        source=None,
        username=settings.SUPERUSER,
    )
    worker.output_callback = output_callback
    # Stream only when a listener is attached to receive incremental output.
    worker.force_streaming = bool(output_callback)
    worker.suppress_user_reply = suppress_user_reply
    try:
        await worker.process(message)
    finally:
        # Single-use session: always release the agent and its memory.
        await worker.cleanup()
        memory_manager.clear_memory(background_session, system_user)
@staticmethod
def _build_heartbeat_prompt() -> str:
    """Build the heartbeat task prompt from the built-in System Tasks definitions."""
    task_key = "heartbeat"
    return prompt_manager.render_system_task_message(task_key)
@staticmethod
def _build_retry_transfer_template_context(
history_ids: list[int],
) -> tuple[str, dict[str, int | str]]:
"""仅负责把失败重试任务的动态数据映射成模板变量。"""
is_batch = len(history_ids) > 1
task_type = (
"batch_transfer_failed_retry" if is_batch else "transfer_failed_retry"
)
template_context: dict[str, int | str] = {
"history_ids_csv": ", ".join(str(item) for item in history_ids),
"history_count": len(history_ids),
}
if not is_batch:
template_context["history_id"] = history_ids[0]
return task_type, template_context
@staticmethod
def _build_retry_transfer_prompt(
    history_ids: list[int],
) -> str:
    """
    Build the unified background retry-transfer prompt, choosing the single
    or batch task template based on how many records failed.
    """
    task_type, context = AgentManager._build_retry_transfer_template_context(
        history_ids
    )
    return prompt_manager.render_system_task_message(
        task_type,
        template_context=context,
    )
@staticmethod
def _build_manual_redo_template_context(history) -> dict[str, int | str]:
"""仅负责把整理历史对象映射成 System Tasks 需要的模板变量。"""
src_fileitem = history.src_fileitem or {}
source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
source_path = source_path or history.src or ""
season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
# 这里故意只做数据整形,具体行为定义全部交给内置 System Tasks YAML。
return {
"history_id": history.id,
"current_status": "success" if history.status else "failed",
"recognized_title": history.title or "unknown",
"media_type": history.type or "unknown",
"category": history.category or "unknown",
"year": history.year or "unknown",
"season_episode": season_episode or "unknown",
"source_path": source_path or "unknown",
"source_storage": history.src_storage or "local",
"destination_path": history.dest or "unknown",
"destination_storage": history.dest_storage or "unknown",
"transfer_mode": history.mode or "unknown",
"tmdbid": history.tmdbid or "none",
"doubanid": history.doubanid or "none",
"error_message": history.errmsg or "none",
}
async def heartbeat_check_jobs(self):
"""
心跳唤醒检查并执行待处理的定时任务Jobs
@@ -1097,135 +1055,6 @@ class AgentManager:
except Exception as e:
logger.error(f"智能体心跳唤醒失败: {e}")
async def retry_failed_transfer(self, history_id: int, group_key: str = ""):
"""
Ask the agent to re-organize a failed transfer-history record.

Called by the file-transfer module when it detects a failed transfer.
Failed records sharing the same group_key are merged into one agent call
within the debounce window, avoiding repeated token-wasting calls.

:param history_id: ID of the failed transfer-history record
:param group_key: grouping key; records with the same key are merged
                  (e.g. download_hash, source directory)
"""
# Fall back to a per-record group so ungrouped records never merge.
if not group_key:
group_key = f"_default_{history_id}"
async with self._retry_transfer_lock:
# Add history_id to this group's buffer (deduplicated).
if group_key not in self._retry_transfer_buffer:
self._retry_transfer_buffer[group_key] = []
if history_id not in self._retry_transfer_buffer[group_key]:
self._retry_transfer_buffer[group_key].append(history_id)
logger.info(
f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
)
# Cancel the group's previous debounce timer, if any.
if group_key in self._retry_transfer_timers:
self._retry_transfer_timers[group_key].cancel()
# Arm a fresh delay timer; when it fires, the whole group is flushed.
loop = asyncio.get_running_loop()
self._retry_transfer_timers[group_key] = loop.call_later(
self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
# bind gk now — the callback runs later, after group_key may have changed
lambda gk=group_key: asyncio.ensure_future(
self._flush_retry_transfer(gk)
),
)
async def _flush_retry_transfer(self, group_key: str):
"""
When a group's delay timer fires, take all buffered history_ids of that
group and merge them into a single agent call.

:param group_key: the debounce group whose buffered records are processed
"""
async with self._retry_transfer_lock:
# Atomically drain the group's buffer and drop its timer handle.
history_ids = self._retry_transfer_buffer.pop(group_key, [])
self._retry_transfer_timers.pop(group_key, None)
if not history_ids:
return
# Fresh throwaway session for this batch.
session_id = f"__agent_retry_transfer_batch_{uuid.uuid4().hex[:8]}__"
user_id = SYSTEM_INTERNAL_USER_ID
ids_str = ", ".join(str(i) for i in history_ids)
logger.info(
f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
)
retry_message = self._build_retry_transfer_prompt(history_ids)
try:
await self.process_message(
session_id=session_id,
user_id=user_id,
message=retry_message,
channel=None,
source=None,
username=settings.SUPERUSER,
)
# Wait until the session's message queue has been fully drained.
if session_id in self._session_queues:
await self._session_queues[session_id].join()
# Wait for the session worker task to finish.
if session_id in self._session_workers:
try:
await self._session_workers[session_id]
except asyncio.CancelledError:
pass
logger.info(
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
)
# Single-use session: discard it and release its resources.
await self.clear_session(session_id, user_id)
except Exception as e:
logger.error(
f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}"
)
@staticmethod
def _build_manual_redo_prompt(history) -> str:
    """Build the prompt for a manually triggered AI re-organize of one record."""
    context = AgentManager._build_manual_redo_template_context(history)
    return prompt_manager.render_system_task_message(
        "manual_transfer_redo",
        template_context=context,
    )
async def manual_redo_transfer(
self,
history_id: int,
output_callback: Optional[Callable[[str], None]] = None,
) -> None:
"""
Manually trigger AI re-organize for a single transfer-history record.

:param history_id: ID of the transfer-history record to re-organize
:param output_callback: optional sink that receives streamed agent output
:raises ValueError: if no history record exists for ``history_id``
"""
# Throwaway agent bound to a unique background session.
session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__"
user_id = SYSTEM_INTERNAL_USER_ID
agent = MoviePilotAgent(
session_id=session_id,
user_id=user_id,
channel=None,
source=None,
username=settings.SUPERUSER,
)
agent.output_callback = output_callback
# Always stream so the callback receives incremental progress text.
agent.force_streaming = True
agent.suppress_user_reply = True
try:
history = TransferHistoryOper().get(history_id)
if not history:
raise ValueError(f"整理记录不存在: {history_id}")
await agent.process(self._build_manual_redo_prompt(history))
finally:
# Single-use session: always release the agent and its memory.
await agent.cleanup()
memory_manager.clear_memory(session_id, user_id)
# Global agent-manager instance (module-level singleton).
agent_manager = AgentManager()

View File

@@ -95,3 +95,45 @@ task_types:
- "Do NOT reorganize blindly when media identity is uncertain."
- "If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`."
- "Keep the final response short and focused on outcome."
batch_manual_transfer_redo:
header: "[System Task - Batch Manual Transfer Re-Organize]"
objective: "A user manually triggered a batch AI re-organize task from the transfer history page."
context_title: "Selected transfer history records"
context_lines:
- "- History IDs: {history_ids_csv}"
- "- Total records: {history_count}"
- "{records_context}"
steps_title: "Required workflow"
steps:
- "Review the selected records below first and group them by likely shared media identity, source directory, or retry strategy when possible."
- "Use the provided record context as the primary source of truth. Call `query_transfer_history` only when you need extra confirmation."
- "For each group, decide whether the current recognition is trustworthy."
- "If multiple records clearly belong to the same movie or series, identify the media once with `recognize_media` or `search_media`, then reuse that result for the related records."
- "If a source file no longer exists or cannot be safely processed, skip that record and note the reason."
- "Before re-organizing a record, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file."
- "Then use `transfer_file` to organize the source path directly."
- "When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid or doubanid, and media_type."
- "If a record is already correct and no re-organize is needed, do not perform destructive actions; simply mark it as skipped."
- "Report only the aggregate outcome, including how many records succeeded, skipped, and failed."
task_rules:
- "Do NOT assume every selected record belongs to the same media."
- "When several records obviously share the same media identity, avoid repeated `recognize_media` or `search_media` calls."
- "Process every selected record exactly once."
- "Keep the final response short and focused on the aggregate outcome."
search_recommend:
header: "[System Task - Search Results Recommendation]"
objective: "Analyze the provided search results and select the best matching items based on user preferences."
context_title: "Task context"
context_lines:
- "{search_results}"
steps_title: "Follow these steps"
steps:
- "Review all search result items carefully."
- "Evaluate each item based on the user preference criteria."
- "Select the top items that best match the preferences."
- "Return ONLY a JSON array of item indices."
task_rules:
- "Return ONLY a JSON array of index numbers, e.g., [0, 3, 1]."
- "Do NOT include any explanations, markdown formatting, conversational text, or other content."
- "Do NOT call any tools. Simply analyze and return the JSON result directly."
- "Respond in JSON format only."

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.log import logger
from app.scheduler import Scheduler
class QuerySchedulersInput(BaseModel):
@@ -27,6 +26,8 @@ class QuerySchedulersTool(MoviePilotTool):
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")
try:
from app.scheduler import Scheduler
scheduler = Scheduler()
schedulers = scheduler.list()
if schedulers:

View File

@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.log import logger
from app.scheduler import Scheduler
class RunSchedulerInput(BaseModel):
@@ -36,6 +35,8 @@ class RunSchedulerTool(MoviePilotTool):
@staticmethod
def _run_scheduler_sync(job_id: str) -> tuple[bool, str]:
"""同步触发定时服务,避免调度器扫描阻塞事件循环。"""
from app.scheduler import Scheduler
scheduler = Scheduler()
for scheduler_item in scheduler.list():
if scheduler_item.id == job_id:

View File

@@ -6,7 +6,6 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.transfer import TransferChain
from app.log import logger
from app.schemas import FileItem, MediaType
@@ -124,6 +123,8 @@ class TransferFileTool(MoviePilotTool):
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
from app.chain.transfer import TransferChain
state, errormsg = TransferChain().manual_transfer(
fileitem=fileitem,
target_storage=target_storage,

View File

@@ -1,14 +1,15 @@
import asyncio
import time
from pathlib import Path
from typing import List, Any, Optional
import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from pathlib import Path
from app import schemas
from app.agent import prompt_manager, agent_manager
from app.chain.storage import StorageChain
from app.core.config import settings, global_vars
from app.core.event import eventmanager
@@ -24,13 +25,99 @@ from app.schemas.types import EventType
router = APIRouter()
def _start_ai_redo_task(history_id: int, progress_key: str):
from app.agent import agent_manager
def normalize_history_ids(history_ids: list[int]) -> list[int]:
    """
    Normalize the input history-record ID list: drop duplicates while
    preserving first-seen order.

    :param history_ids: raw ID list, possibly containing duplicates
    :return: a new list with duplicates removed, original order kept
    """
    # dict preserves insertion order (Python 3.7+), giving O(n) dedup
    # instead of the original O(n^2) `id in list` membership scan.
    return list(dict.fromkeys(history_ids))
def build_manual_redo_template_context(history: TransferHistory) -> dict[str, int | str]:
    """
    Map a transfer-history record onto the System Tasks template variables.

    Pure data shaping only; the task behaviour itself is defined by the
    built-in System Tasks YAML.
    """
    fileitem = history.src_fileitem or {}
    item_path = fileitem.get("path") if isinstance(fileitem, dict) else ""
    resolved_source = item_path or history.src or ""
    season_episode = "{}{}".format(history.seasons or "", history.episodes or "").strip()
    context: dict[str, int | str] = {
        "history_id": history.id,
        "current_status": "success" if history.status else "failed",
        "recognized_title": history.title or "unknown",
        "media_type": history.type or "unknown",
        "category": history.category or "unknown",
        "year": history.year or "unknown",
        "season_episode": season_episode or "unknown",
        "source_path": resolved_source or "unknown",
        "source_storage": history.src_storage or "local",
        "destination_path": history.dest or "unknown",
        "destination_storage": history.dest_storage or "unknown",
        "transfer_mode": history.mode or "unknown",
        "tmdbid": history.tmdbid or "none",
        "doubanid": history.doubanid or "none",
        "error_message": history.errmsg or "none",
    }
    return context
def format_manual_redo_record_context(history: Any) -> str:
    """Render one transfer-history record as a context block the batch task can consume."""
    ctx = build_manual_redo_template_context(history)
    # (label, template-key) pairs, in the exact order the prompt expects.
    field_labels = [
        ("Current status", "current_status"),
        ("Current recognized title", "recognized_title"),
        ("Media type", "media_type"),
        ("Category", "category"),
        ("Year", "year"),
        ("Season/Episode", "season_episode"),
        ("Source path", "source_path"),
        ("Source storage", "source_storage"),
        ("Destination path", "destination_path"),
        ("Destination storage", "destination_storage"),
        ("Transfer mode", "transfer_mode"),
        ("Current TMDB ID", "tmdbid"),
        ("Current Douban ID", "doubanid"),
        ("Error message", "error_message"),
    ]
    lines = [f"Record #{ctx['history_id']}:"]
    lines.extend(f"- {label}: {ctx[key]}" for label, key in field_labels)
    return "\n".join(lines)
def build_manual_redo_prompt(history: Any) -> str:
    """Build the prompt for a manually triggered AI re-organize of one record."""
    context = build_manual_redo_template_context(history)
    return prompt_manager.render_system_task_message(
        "manual_transfer_redo",
        template_context=context,
    )
def build_batch_manual_redo_template_context(
    histories: list[Any],
) -> dict[str, int | str]:
    """Map multiple transfer-history records onto the batch System Tasks template variables."""
    id_list = [str(record.id) for record in histories]
    record_blocks = [format_manual_redo_record_context(record) for record in histories]
    return {
        "history_ids_csv": ", ".join(id_list),
        "history_count": len(histories),
        # Blank-line-separated per-record context blocks for the prompt body.
        "records_context": "\n\n".join(record_blocks),
    }
def build_batch_manual_redo_prompt(histories: list[Any]) -> str:
    """Build the prompt for a batch manual AI re-organize of the given records."""
    context = build_batch_manual_redo_template_context(histories)
    return prompt_manager.render_system_task_message(
        "batch_manual_transfer_redo",
        template_context=context,
    )
def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
"""在后台线程中启动单条 AI 重新整理任务,并通过 ProgressHelper 实时更新进度。"""
progress = ProgressHelper(progress_key)
progress.start()
progress.update(
text=f"智能助正在准备整理记录 #{history_id} ...",
text=f"智能助手正在准备整理记录 #{history_id} ...",
data={"history_id": history_id, "success": True},
)
@@ -39,9 +126,11 @@ def _start_ai_redo_task(history_id: int, progress_key: str):
async def runner():
try:
await agent_manager.manual_redo_transfer(
history_id=history_id,
await agent_manager.run_background_prompt(
message=prompt,
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=update_output,
suppress_user_reply=True,
)
progress.update(
text="智能助手整理完成",
@@ -63,6 +152,50 @@ def _start_ai_redo_task(history_id: int, progress_key: str):
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
def _start_batch_ai_redo_task(
history_ids: list[int],
prompt: str,
progress_key: str,
):
"""
Start a batch AI re-organize task in the background and report progress in
real time through ProgressHelper.

:param history_ids: IDs of the transfer-history records to re-organize
:param prompt: pre-rendered System Task prompt covering the whole batch
:param progress_key: key the frontend polls for progress updates
"""
progress = ProgressHelper(progress_key)
progress.start()
progress.update(
text=f"智能助手正在准备批量整理 {len(history_ids)} 条记录 ...",
data={"history_ids": history_ids, "success": True},
)
# Relay the agent's streamed output into the progress record.
def update_output(text: str):
progress.update(text=text, data={"history_ids": history_ids})
async def runner():
try:
await agent_manager.run_background_prompt(
message=prompt,
session_prefix="__agent_manual_redo_batch",
output_callback=update_output,
suppress_user_reply=True,
)
progress.update(
text="智能助手批量整理完成",
data={"history_ids": history_ids, "success": True, "completed": True},
)
except Exception as e:
progress.update(
text=f"智能助手批量整理失败:{str(e)}",
data={
"history_ids": history_ids,
"success": False,
"completed": True,
"error": str(e),
},
)
finally:
# Always close the progress record so pollers see a terminal state.
progress.end()
# Called from a sync request handler: schedule the coroutine onto the
# running event loop's thread instead of awaiting it here.
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
async def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -159,9 +292,9 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
@router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response)
def ai_redo_transfer_history(
history_id: int,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser),
history_id: int,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
手动触发单条历史记录的 AI 重新整理,并返回进度键。
@@ -173,12 +306,62 @@ def ai_redo_transfer_history(
if not history:
return schemas.Response(success=False, message="整理记录不存在")
prompt = build_manual_redo_prompt(history)
progress_key = f"ai_redo_transfer_{history_id}_{int(time.time() * 1000)}"
_start_ai_redo_task(history_id=history_id, progress_key=progress_key)
_start_ai_redo_task(
history_id=history_id,
prompt=prompt,
progress_key=progress_key,
)
return schemas.Response(success=True, data={"progress_key": progress_key})
@router.post("/transfer/ai-redo", summary="智能助手批量重新整理", response_model=schemas.Response)
def batch_ai_redo_transfer_history(
payload: schemas.BatchTransferHistoryRedoRequest,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
手动触发多条历史记录的 AI 批量重新整理,并返回进度键。
"""
if not settings.AI_AGENT_ENABLE:
return schemas.Response(success=False, message="MoviePilot智能助手未启用")
history_ids = normalize_history_ids(payload.history_ids)
if not history_ids:
return schemas.Response(success=False, message="未提供有效的整理记录")
histories = []
missing_ids = []
for history_id in history_ids:
history = TransferHistory.get(db, history_id)
if not history:
missing_ids.append(history_id)
continue
histories.append(history)
if missing_ids:
return schemas.Response(
success=False,
message="整理记录不存在: " + ", ".join(str(history_id) for history_id in missing_ids),
)
prompt = build_batch_manual_redo_prompt(histories)
progress_key = f"ai_redo_transfer_batch_{int(time.time() * 1000)}"
_start_batch_ai_redo_task(
history_ids=history_ids,
prompt=prompt,
progress_key=progress_key,
)
return schemas.Response(
success=True,
data={"progress_key": progress_key, "history_ids": history_ids},
)
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:

View File

@@ -5,9 +5,9 @@ from fastapi import APIRouter, Depends, Body, Request
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.agent import AIRecommendChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.chain.ai_recommend import AIRecommendChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
@@ -73,7 +73,6 @@ async def search_by_id_stream(request: Request,
"""
根据TMDBID/豆瓣ID渐进式搜索站点资源返回格式为SSE
"""
AIRecommendChain().cancel_ai_recommend()
media_type = MediaType(mtype) if mtype else None
media_season = int(season) if season else None
@@ -206,8 +205,7 @@ async def search_by_id(mediaid: str,
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
# 取消正在运行的AI推荐会清除数据库缓存
AIRecommendChain().cancel_ai_recommend()
if mtype:
media_type = MediaType(mtype)
else:
@@ -332,7 +330,6 @@ async def search_by_title_stream(request: Request,
"""
根据名称渐进式模糊搜索站点资源返回格式为SSE
"""
AIRecommendChain().cancel_ai_recommend()
event_source = SearchChain().async_search_by_title_stream(
title=keyword,
@@ -352,8 +349,7 @@ async def search_by_title(keyword: Optional[str] = None,
根据名称模糊搜索站点资源,支持分页,关键词为空时返回首页资源
"""
# 取消正在运行的AI推荐并清除数据库缓存
AIRecommendChain().cancel_ai_recommend()
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=_parse_site_list(sites),
@@ -396,9 +392,9 @@ async def recommend_search_results(
return schemas.Response(success=False, message="没有可用的搜索结果", data={
"status": "error"
})
recommend_chain = AIRecommendChain()
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
if force:
# 检查功能是否启用
@@ -413,7 +409,7 @@ async def recommend_search_results(
return schemas.Response(success=True, data={
"status": "running"
})
# 如果是仅检查模式,不传递 filtered_indices避免触发请求变化检测
if check_only:
# 返回当前运行状态,不做任何任务启动或取消操作
@@ -423,14 +419,14 @@ async def recommend_search_results(
error_msg = current_status.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=current_status)
return schemas.Response(success=True, data=current_status)
# 获取当前状态(会检测请求是否变化)
status_data = recommend_chain.get_status(filtered_indices, len(results))
# 如果功能未启用,直接返回禁用状态
if status_data.get("status") == "disabled":
return schemas.Response(success=True, data=status_data)
# 如果是空闲状态,启动新任务
if status_data["status"] == "idle":
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
@@ -438,11 +434,11 @@ async def recommend_search_results(
return schemas.Response(success=True, data={
"status": "running"
})
# 如果有错误将错误信息放到message中
if status_data.get("status") == "error":
error_msg = status_data.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=status_data)
# 返回当前状态
return schemas.Response(success=True, data=status_data)

View File

@@ -1,13 +1,13 @@
import re
from typing import List, Optional, Dict, Any
import asyncio
import hashlib
import json
import re
from typing import Any, Dict, List, Optional
from app.agent import agent_manager, prompt_manager
from app.chain import ChainBase
from app.core.config import settings
from app.log import logger
from app.utils.common import log_execution_time
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
@@ -16,17 +16,16 @@ class AIRecommendChain(ChainBase, metaclass=Singleton):
"""
AI推荐处理链单例运行
用于基于搜索结果的AI智能推荐
使用 agent_manager.run_background_prompt 统一后台任务机制
"""
# 缓存文件名
__ai_indices_cache_file = "__ai_recommend_indices__"
# AI推荐状态
_ai_recommend_running = False
_ai_recommend_task: Optional[asyncio.Task] = None
_current_request_hash: Optional[str] = None # 当前请求的哈希值
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存索引列表
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
_current_request_hash: Optional[str] = None
_ai_recommend_result: Optional[List[int]] = None
_ai_recommend_error: Optional[str] = None
@staticmethod
def _calculate_request_hash(
@@ -53,7 +52,6 @@ class AIRecommendChain(ChainBase, metaclass=Singleton):
def _build_status(self) -> Dict[str, Any]:
"""
构建AI推荐状态字典
:return: 状态字典
"""
if not self.is_enabled:
return {"status": "disabled"}
@@ -61,13 +59,11 @@ class AIRecommendChain(ChainBase, metaclass=Singleton):
if self._ai_recommend_running:
return {"status": "running"}
# 尝试从数据库加载缓存
if self._ai_recommend_result is None:
cached_indices = self.load_cache(self.__ai_indices_cache_file)
if cached_indices is not None:
self._ai_recommend_result = cached_indices
# 只要有结果始终返回completed状态和数据
if self._ai_recommend_result is not None:
return {"status": "completed", "results": self._ai_recommend_result}
@@ -89,76 +85,16 @@ class AIRecommendChain(ChainBase, metaclass=Singleton):
获取AI推荐状态并检查请求是否变化用于首次请求或force模式
如果请求变化筛选条件变化返回idle状态
"""
# 计算当前请求的hash
request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# 检查请求是否变化
is_same_request = request_hash == self._current_request_hash
# 如果请求变化了筛选条件改变返回idle状态
if not is_same_request:
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
# 请求未变化,返回当前实际状态
return self._build_status()
@log_execution_time(logger=logger)
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
"""
AI推荐
:param items: 候选资源列表(JSON字符串格式)
:param preference: 用户偏好(可选)
:return: AI返回的推荐结果
"""
# 设置运行状态
self._ai_recommend_running = True
try:
# 导入LLMHelper
from app.helper.llm import LLMHelper
# 获取LLM实例
llm = LLMHelper.get_llm()
# 构建提示词
user_preference = (
preference
or settings.AI_RECOMMEND_USER_PREFERENCE
or "Prefer high-quality resources with more seeders"
)
# 添加指令
instruction = """
Task: Select the best matching items from the list based on user preferences.
Each item contains:
- index: Item number
- title: Full torrent title
- size: File size
- seeders: Number of seeders
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
"""
message = (
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
+ "\n".join(items)
)
# 调用LLM
response = await llm.ainvoke(message)
return response.content
except ValueError as e:
logger.error(f"AI推荐配置错误: {e}")
raise
except Exception as e:
raise
finally:
# 清除运行状态
self._ai_recommend_running = False
self._ai_recommend_task = None
def is_ai_recommend_running(self) -> bool:
"""
检查AI推荐是否正在运行
@@ -186,44 +122,34 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do
) -> None:
"""
启动AI推荐任务
使用 agent_manager.run_background_prompt 后台Agent机制执行推荐
:param filtered_indices: 筛选后的索引列表
:param search_results_count: 搜索结果总数
:param results: 搜索结果列表
"""
# 防护检查确保AI推荐功能已启用
if not self.is_enabled:
logger.warning("AI推荐功能未启用跳过任务执行")
return
# 计算新请求的哈希值
new_request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# 如果请求变化了,取消旧任务
if new_request_hash != self._current_request_hash:
self.cancel_ai_recommend()
# 更新请求哈希值
self._current_request_hash = new_request_hash
# 重置状态
self._ai_recommend_result = None
self._ai_recommend_error = None
# 启动新任务
async def run_recommend():
# 获取当前任务对象用于在finally中比对
current_task = asyncio.current_task()
try:
self._ai_recommend_running = True
# 准备数据
items = []
valid_indices = []
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
if filtered_indices is not None and len(filtered_indices) > 0:
results_to_process = [
results[i]
@@ -259,27 +185,54 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do
self._ai_recommend_error = "没有可用于AI推荐的资源"
return
# 调用AI推荐
ai_response = await self.async_ai_recommend(items)
user_preference = (
settings.AI_RECOMMEND_USER_PREFERENCE
or "Prefer high-quality resources with more seeders"
)
search_results_text = "User Preference: {preference}\n\nCandidate Resources:\n{items}".format(
preference=user_preference, items="\n".join(items)
)
prompt = prompt_manager.render_system_task_message(
"search_recommend",
template_context={"search_results": search_results_text},
)
full_output = [""]
def on_output(text: str):
full_output[0] = text
await agent_manager.run_background_prompt(
message=prompt,
session_prefix="__agent_search_recommend",
output_callback=on_output,
suppress_user_reply=True,
)
ai_response = full_output[0]
if not ai_response:
self._ai_recommend_error = "AI推荐未返回结果"
return
# 解析AI返回的索引
try:
# 使用正则提取JSON数组非贪婪模式避免匹配多个数组
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
if not json_match:
raise ValueError(ai_response)
raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
ai_indices = json.loads(json_match.group())
if not isinstance(ai_indices, list):
raise ValueError(f"AI返回格式错误: {ai_response}")
# 映射回原始索引
if filtered_indices:
original_indices = [
filtered_indices[valid_indices[i]]
for i in ai_indices
if i < len(valid_indices)
and 0 <= filtered_indices[valid_indices[i]] < len(results)
and 0
<= filtered_indices[valid_indices[i]]
< len(results)
]
else:
original_indices = [
@@ -289,10 +242,7 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do
and 0 <= valid_indices[i] < len(results)
]
# 只返回索引列表,不返回完整数据
self._ai_recommend_result = original_indices
# 保存到数据库
self.save_cache(original_indices, self.__ai_indices_cache_file)
logger.info(f"AI推荐完成: {len(original_indices)}")
@@ -308,11 +258,8 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do
logger.error(f"AI推荐任务失败: {e}")
self._ai_recommend_error = str(e)
finally:
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
# 如果任务被取消并启动了新任务self._ai_recommend_task 已经指向新任务,不应重置
if self._ai_recommend_task == current_task:
self._ai_recommend_running = False
self._ai_recommend_task = None
# 创建并启动任务
self._ai_recommend_task = asyncio.create_task(run_recommend())

View File

@@ -1,16 +1,15 @@
import asyncio
import base64
import mimetypes
import re
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Dict, Union, List
from urllib.parse import unquote, urlparse
import uuid
import base64
from app.agent import agent_manager
from app.agent import agent_manager, prompt_manager
from app.chain import ChainBase
from app.chain.interaction import (
MediaInteractionChain,
@@ -20,6 +19,8 @@ from app.chain.interaction import (
from app.chain.skills import SkillsChain, skills_interaction_manager
from app.chain.transfer import TransferChain
from app.core.config import settings, global_vars
from app.db.models import TransferHistory
from app.db.transferhistory_oper import TransferHistoryOper
from app.helper.llm import LLMHelper
from app.helper.voice import VoiceHelper
from app.log import logger
@@ -92,17 +93,17 @@ class MessageChain(ChainBase):
)
def handle_message(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
images: Optional[List[CommingMessage.MessageImage]] = None,
audio_refs: Optional[List[str]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
images: Optional[List[CommingMessage.MessageImage]] = None,
audio_refs: Optional[List[str]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
) -> None:
"""
识别消息内容,执行操作
@@ -171,21 +172,21 @@ class MessageChain(ChainBase):
if skills_interaction_manager.get_by_user(userid):
if SkillsChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
if media_interaction_manager.get_by_user(userid):
if MediaInteractionChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
@@ -202,8 +203,8 @@ class MessageChain(ChainBase):
return
if (
settings.AI_AGENT_ENABLE
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
settings.AI_AGENT_ENABLE
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
):
self._handle_ai_message(
text=text,
@@ -217,11 +218,11 @@ class MessageChain(ChainBase):
return
if MediaInteractionChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
@@ -236,14 +237,14 @@ class MessageChain(ChainBase):
)
def _handle_callback(
self,
text: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
self,
text: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
处理按钮回调
@@ -254,44 +255,44 @@ class MessageChain(ChainBase):
logger.info(f"处理按钮回调:{callback_data}")
if self._handle_transfer_callback(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
):
return
if SkillsChain().handle_callback_interaction(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
if MediaInteractionChain().handle_callback_interaction(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
if self._handle_agent_choice_callback(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
@@ -327,14 +328,14 @@ class MessageChain(ChainBase):
@staticmethod
def _parse_transfer_callback(
callback_data: str,
callback_data: str,
) -> Optional[tuple[str, int]]:
"""
解析整理失败通知按钮回调。
"""
for prefix, action in (
("transfer_retry_", "retry"),
("transfer_ai_retry_", "ai_retry"),
("transfer_retry_", "retry"),
("transfer_ai_retry_", "ai_retry"),
):
if callback_data.startswith(prefix):
history_id = callback_data.replace(prefix, "", 1)
@@ -343,12 +344,12 @@ class MessageChain(ChainBase):
return None
def _handle_transfer_callback(
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> bool:
"""
处理整理失败通知中的重试类按钮。
@@ -378,7 +379,7 @@ class MessageChain(ChainBase):
@staticmethod
def _parse_agent_choice_callback(
callback_data: str,
callback_data: str,
) -> Optional[tuple[str, int]]:
"""
解析 Agent 按钮选择回调。
@@ -401,14 +402,14 @@ class MessageChain(ChainBase):
return request_id, int(option_index)
def _handle_agent_choice_callback(
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
@@ -465,14 +466,14 @@ class MessageChain(ChainBase):
return True
def _update_interaction_message_feedback(
self,
channel: MessageChannel,
source: str,
original_message_id: Optional[Union[str, int]],
original_chat_id: Optional[str],
prompt: str,
selected_label: str,
title: Optional[str] = None,
self,
channel: MessageChannel,
source: str,
original_message_id: Optional[Union[str, int]],
original_chat_id: Optional[str],
prompt: str,
selected_label: str,
title: Optional[str] = None,
) -> None:
"""
在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。
@@ -494,12 +495,12 @@ class MessageChain(ChainBase):
)
def _retry_transfer_history(
self,
history_id: int,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
self,
history_id: int,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> None:
"""
立即重新整理一条失败的整理记录。
@@ -541,16 +542,46 @@ class MessageChain(ChainBase):
)
def _take_over_transfer_history_by_ai(
self,
history_id: int,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
self,
history_id: int,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> None:
"""
由智能助手接管一条失败的整理记录。
"""
def __build_manual_redo_prompt(his: TransferHistory) -> str:
    """Render the system-task prompt used to let the AI agent redo one transfer record.

    Collects the record's recognition result, source/destination locations and
    error details into template variables, then renders the
    ``manual_transfer_redo`` system-task template.
    """
    # Prefer the structured source file item's path; fall back to the legacy
    # ``src`` column when the fileitem is missing or not a dict.
    raw_fileitem = his.src_fileitem or {}
    if isinstance(raw_fileitem, dict):
        source_path = raw_fileitem.get("path")
    else:
        source_path = ""
    source_path = source_path or his.src or ""
    # Season/episode are stored separately; join and strip so "S01" + "E02" -> "S01E02".
    season_episode = f"{his.seasons or ''}{his.episodes or ''}".strip()
    context = {
        "his_id": his.id,
        "current_status": "success" if his.status else "failed",
        "recognized_title": his.title or "unknown",
        "media_type": his.type or "unknown",
        "category": his.category or "unknown",
        "year": his.year or "unknown",
        "season_episode": season_episode or "unknown",
        "source_path": source_path or "unknown",
        "source_storage": his.src_storage or "local",
        "destination_path": his.dest or "unknown",
        "destination_storage": his.dest_storage or "unknown",
        "transfer_mode": his.mode or "unknown",
        "tmdbid": his.tmdbid or "none",
        "doubanid": his.doubanid or "none",
        "error_message": his.errmsg or "none",
    }
    return prompt_manager.render_system_task_message(
        "manual_transfer_redo",
        template_context=context,
    )
if not settings.AI_AGENT_ENABLE:
self.post_message(
Notification(
@@ -563,6 +594,23 @@ class MessageChain(ChainBase):
)
return
history = TransferHistoryOper().get(history_id)
if not history:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="重新整理失败",
text=f"整理记录 #{history_id} 不存在",
link=settings.MP_DOMAIN("#/history"),
)
)
return
redo_prompt = __build_manual_redo_prompt(history)
self.post_message(
Notification(
channel=channel,
@@ -583,9 +631,11 @@ class MessageChain(ChainBase):
final_output = text_output or ""
try:
await agent_manager.manual_redo_transfer(
history_id=history_id,
await agent_manager.run_background_prompt(
message=redo_prompt,
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=_capture_output,
suppress_user_reply=True,
)
await self.async_post_message(
Notification(
@@ -595,7 +645,7 @@ class MessageChain(ChainBase):
username=username,
title="智能助手整理完成",
text=final_output.strip()
or f"整理记录 #{history_id} 已由智能助手处理完成。",
or f"整理记录 #{history_id} 已由智能助手处理完成。",
link=settings.MP_DOMAIN("#/history"),
)
)
@@ -650,12 +700,12 @@ class MessageChain(ChainBase):
self._user_sessions[userid] = (session_id, datetime.now())
def _record_user_message(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
) -> None:
"""
保存一条用户消息到消息历史与数据库。
@@ -690,10 +740,10 @@ class MessageChain(ChainBase):
return False
def remote_clear_session(
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
):
"""
清除用户会话(远程命令接口)
@@ -735,10 +785,10 @@ class MessageChain(ChainBase):
)
def remote_stop_agent(
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
):
"""
应急停止当前正在执行的Agent推理远程命令接口
@@ -805,7 +855,7 @@ class MessageChain(ChainBase):
f"({context_ratio * 100:.2f}%)"
if context_ratio is not None
else f"{cls._format_token_count(last_input_tokens)} / "
f"{cls._format_token_count(context_window_tokens)}"
f"{cls._format_token_count(context_window_tokens)}"
)
else:
context_usage_text = "暂无模型调用数据"
@@ -825,10 +875,10 @@ class MessageChain(ChainBase):
return "\n".join(lines)
def remote_session_status(
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
):
"""查询当前用户的智能体会话状态。"""
session_info = self._user_sessions.get(userid)
@@ -856,15 +906,15 @@ class MessageChain(ChainBase):
)
def _handle_ai_message(
self,
text: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
session_id: Optional[str] = None,
self,
text: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
session_id: Optional[str] = None,
) -> None:
"""
处理AI智能体消息
@@ -928,10 +978,10 @@ class MessageChain(ChainBase):
elif images:
image_attachments = self._build_image_attachments(images)
if (
original_images
and not image_attachments
and not user_message
and not files
original_images
and not image_attachments
and not user_message
and not files
):
self.post_message(
Notification(
@@ -986,7 +1036,7 @@ class MessageChain(ChainBase):
)
def _transcribe_audio_refs(
self, audio_refs: List[str], channel: MessageChannel, source: str
self, audio_refs: List[str], channel: MessageChannel, source: str
) -> Optional[str]:
"""
下载并识别语音消息,仅处理当前已接入的渠道。
@@ -1119,10 +1169,10 @@ class MessageChain(ChainBase):
return default
def _download_attachments_to_data_urls(
self,
attachments: List[CommingMessage.MessageImage],
channel: MessageChannel,
source: str,
self,
attachments: List[CommingMessage.MessageImage],
channel: MessageChannel,
source: str,
) -> Optional[List[str]]:
"""
下载可直接提供给 LLM 的附件内容,并统一转换为 data URL。
@@ -1147,7 +1197,7 @@ class MessageChain(ChainBase):
if base64_data:
data_urls.append(f"data:image/jpeg;base64,{base64_data}")
elif attachment_ref.startswith(
"wxwork://media_id/"
"wxwork://media_id/"
) or attachment_ref.startswith(
"wxbot://image/"
):
@@ -1208,7 +1258,7 @@ class MessageChain(ChainBase):
return data_urls if data_urls else None
def _build_image_attachments(
self, images: List[CommingMessage.MessageImage]
self, images: List[CommingMessage.MessageImage]
) -> List[CommingMessage.MessageAttachment]:
"""
将图片引用转换为附件描述,以便按文件方式交给 Agent 处理。
@@ -1235,11 +1285,11 @@ class MessageChain(ChainBase):
return attachments
def _prepare_agent_files(
self,
session_id: str,
files: Optional[List[CommingMessage.MessageAttachment]],
channel: MessageChannel,
source: str,
self,
session_id: str,
files: Optional[List[CommingMessage.MessageAttachment]],
channel: MessageChannel,
source: str,
) -> Optional[List[dict]]:
"""
下载用户上传的附件,落盘到临时目录,并生成 Agent 可消费的文件描述。
@@ -1286,7 +1336,7 @@ class MessageChain(ChainBase):
return prepared_files or None
def _download_message_file_bytes(
self, file_ref: str, channel: MessageChannel, source: str
self, file_ref: str, channel: MessageChannel, source: str
) -> Optional[bytes]:
"""
下载消息附件的原始字节内容。
@@ -1359,11 +1409,11 @@ class MessageChain(ChainBase):
return None
def _save_agent_attachment(
self,
session_id: str,
filename: Optional[str],
content: bytes,
mime_type: Optional[str] = None,
self,
session_id: str,
filename: Optional[str],
content: bytes,
mime_type: Optional[str] = None,
) -> Path:
"""
将用户上传文件写入临时目录,并返回本地路径。
@@ -1379,7 +1429,7 @@ class MessageChain(ChainBase):
@staticmethod
def _sanitize_attachment_name(
filename: Optional[str], mime_type: Optional[str] = None
filename: Optional[str], mime_type: Optional[str] = None
) -> str:
"""
规范化附件文件名,避免路径穿越和非法字符。
@@ -1449,5 +1499,6 @@ class MessageChain(ChainBase):
return None
try:
return base64.b64decode(payload)
except Exception:
except Exception as e:
logger.error(e)
return None

View File

@@ -724,8 +724,7 @@ class SkillsChain(ChainBase):
"""
if request.view == "installed":
title, text, buttons = self._build_installed_view(
request=request,
force_market_refresh=force_market_refresh,
request=request
)
elif request.view == "market":
title, text, buttons = self._build_market_view(
@@ -735,7 +734,6 @@ class SkillsChain(ChainBase):
elif request.view == "sources":
title, text, buttons = self._build_sources_view(
request=request,
force_market_refresh=force_market_refresh,
)
else:
title, text, buttons = self._build_root_view(
@@ -808,8 +806,7 @@ class SkillsChain(ChainBase):
def _build_installed_view(
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False, # noqa: ARG002
request: PendingSkillsInteraction
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建已安装技能视图,列出来源和可删除状态。
@@ -971,7 +968,6 @@ class SkillsChain(ChainBase):
def _build_sources_view(
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False, # noqa: ARG002
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建技能源管理视图,提供自定义 GitHub 源的增删入口。

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict, Callable
from app import schemas
from app.agent import prompt_manager, agent_manager
from app.chain import ChainBase
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
@@ -162,10 +163,10 @@ class JobManager:
else:
# 不重复添加任务
if any(
[
t.fileitem == task.fileitem
for t in self._job_view[__mediaid__].tasks
]
[
t.fileitem == task.fileitem
for t in self._job_view[__mediaid__].tasks
]
):
logger.debug(f"任务 {task.fileitem.name} 已存在,跳过重复添加")
return False
@@ -301,7 +302,7 @@ class JobManager:
return task
def __remove_task_with_job_id(
self, fileitem: FileItem
self, fileitem: FileItem
) -> Tuple[Optional[TransferJobTask], Optional[Tuple]]:
"""
根据文件项移除任务并返回任务所在的作业ID
@@ -462,10 +463,10 @@ class JobManager:
"""
with job_lock:
if any(
task.state not in {"completed", "failed"}
for job in self._job_view.values()
for task in job.tasks
if task.download_hash == download_hash
task.state not in {"completed", "failed"}
for job in self._job_view.values()
for task in job.tasks
if task.download_hash == download_hash
):
return False
return True
@@ -476,19 +477,19 @@ class JobManager:
"""
with job_lock:
if any(
task.state != "completed"
for job in self._job_view.values()
for task in job.tasks
if task.download_hash == download_hash
task.state != "completed"
for job in self._job_view.values()
for task in job.tasks
if task.download_hash == download_hash
):
return False
return True
def has_tasks(
self,
meta: MetaBase,
mediainfo: Optional[MediaInfo] = None,
season: Optional[int] = None,
self,
meta: MetaBase,
mediainfo: Optional[MediaInfo] = None,
season: Optional[int] = None,
) -> bool:
"""
判断作业是否还有任务正在处理
@@ -501,12 +502,12 @@ class JobManager:
__metaid__ = self.__get_meta_id(meta=meta, season=season)
return (
__metaid__ in self._job_view
and len(self._job_view[__metaid__].tasks) > 0
__metaid__ in self._job_view
and len(self._job_view[__metaid__].tasks) > 0
)
def success_tasks(
self, media: MediaInfo, season: Optional[int] = None
self, media: MediaInfo, season: Optional[int] = None
) -> List[TransferJobTask]:
"""
获取作业中所有成功的任务
@@ -522,7 +523,7 @@ class JobManager:
]
def all_tasks(
self, media: MediaInfo, season: Optional[int] = None
self, media: MediaInfo, season: Optional[int] = None
) -> List[TransferJobTask]:
"""
获取作业中全部任务
@@ -586,7 +587,7 @@ class JobManager:
return list(self._job_view.values())
def season_episodes(
self, media: MediaInfo, season: Optional[int] = None
self, media: MediaInfo, season: Optional[int] = None
) -> List[int]:
"""
获取作业的季集清单
@@ -596,6 +597,107 @@ class JobManager:
return self._season_episodes.get(__mediaid__) or []
class FailedRetryScheduler:
    """
    Debounce aggregator and AI-retry scheduler for failed transfer records.

    Failed history IDs sharing a ``group_key`` (e.g. the same torrent hash)
    are buffered for a debounce window and then merged into a single
    background agent call.
    """

    # Debounce window (seconds): failures arriving within this window for the
    # same group are merged into one agent invocation.
    RETRY_TRANSFER_DEBOUNCE_SECONDS = 300

    def __init__(self):
        super().__init__()
        # group_key -> buffered failed-history IDs awaiting a flush
        self._retry_transfer_buffer: dict[str, list[int]] = {}
        # group_key -> pending debounce timer for that group
        self._retry_transfer_timers: dict[str, asyncio.TimerHandle] = {}
        # Guards both dicts above across concurrent schedule/flush/close calls
        self._retry_transfer_lock = asyncio.Lock()

    async def close(self):
        """Cancel all pending debounce timers and drop buffered IDs.

        NOTE(review): buffered history IDs are discarded without being
        flushed — intended as a fast shutdown path.
        """
        async with self._retry_transfer_lock:
            # Snapshot under the lock, then cancel outside it so timer
            # callbacks racing for the lock cannot deadlock with us.
            timers = list(self._retry_transfer_timers.values())
            self._retry_transfer_timers.clear()
            self._retry_transfer_buffer.clear()
        for timer in timers:
            timer.cancel()

    @staticmethod
    def _build_retry_transfer_template_context(
        history_ids: list[int],
    ) -> tuple[str, dict[str, int | str]]:
        """Map the retry task's dynamic data to template variables only.

        Returns the task-template name (batch vs. single variant) and the
        context dict used to render it.
        """
        is_batch = len(history_ids) > 1
        task_type = "batch_transfer_failed_retry" if is_batch else "transfer_failed_retry"
        template_context: dict[str, int | str] = {
            "history_ids_csv": ", ".join(str(item) for item in history_ids),
            "history_count": len(history_ids),
        }
        if not is_batch:
            # Single-record template expects a scalar ``history_id`` as well.
            template_context["history_id"] = history_ids[0]
        return task_type, template_context

    def _build_retry_transfer_prompt(self, history_ids: list[int]) -> str:
        """Build the unified background-task retry prompt based on record count."""
        task_type, template_context = self._build_retry_transfer_template_context(history_ids)
        return prompt_manager.render_system_task_message(
            task_type,
            template_context=template_context,
        )

    async def schedule_retry(self, history_id: int, group_key: str = ""):
        """
        Buffer a failed record for AI retry; records sharing a ``group_key``
        within the debounce window are merged into one agent call.

        An empty ``group_key`` gets a per-record key, i.e. no merging.
        """
        if not group_key:
            group_key = f"_default_{history_id}"
        async with self._retry_transfer_lock:
            if group_key not in self._retry_transfer_buffer:
                self._retry_transfer_buffer[group_key] = []
            if history_id not in self._retry_transfer_buffer[group_key]:
                self._retry_transfer_buffer[group_key].append(history_id)
                logger.info(
                    f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
                    f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
                )
            # Each new arrival restarts the group's debounce timer.
            if group_key in self._retry_transfer_timers:
                self._retry_transfer_timers[group_key].cancel()
            loop = asyncio.get_running_loop()
            self._retry_transfer_timers[group_key] = loop.call_later(
                self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
                # call_later runs a sync callback; wrap the async flush in a task.
                # Default arg binds group_key now, avoiding late-binding closure bugs.
                lambda gk=group_key: asyncio.create_task(self._flush_retry_transfer(gk)),
            )

    async def _flush_retry_transfer(self, group_key: str):
        """
        On debounce expiry, drain the group's buffered history IDs and issue a
        single merged background agent call for them.
        """
        async with self._retry_transfer_lock:
            # pop() both entries atomically so a concurrent schedule_retry
            # starts a fresh buffer/timer for this group.
            history_ids = self._retry_transfer_buffer.pop(group_key, [])
            self._retry_transfer_timers.pop(group_key, None)
        if not history_ids:
            return
        ids_str = ", ".join(str(item) for item in history_ids)
        logger.info(
            f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
        )
        try:
            await agent_manager.run_background_prompt(
                message=self._build_retry_transfer_prompt(history_ids),
                session_prefix="__agent_retry_transfer_batch",
            )
            logger.info(
                f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
            )
        except Exception as err:
            # Background best-effort task: log and swallow so the caller's
            # event loop is never broken by a retry failure.
            logger.error(
                f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {err}"
            )
class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
文件整理处理链
@@ -623,6 +725,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
self._transfer_interval = 15
# 事件管理器
self.jobview = JobManager()
# Agent重试管理器
self.retry_scheduler = FailedRetryScheduler()
# 转移成功的文件清单
self._success_target_files: Dict[str, List[str]] = {}
# 整理进度进度
@@ -713,7 +817,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def __default_callback(
self, task: TransferTask, transferinfo: TransferInfo, /
self, task: TransferTask, transferinfo: TransferInfo, /
) -> Tuple[bool, str]:
"""
整理完成后处理
@@ -730,12 +834,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
# 更新文件数量
transferinfo.file_count = (
self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
)
# 更新文件大小
transferinfo.total_size = (
self.jobview.size(task.mediainfo, task.meta.begin_season)
or task.fileitem.size
self.jobview.size(task.mediainfo, task.meta.begin_season)
or task.fileitem.size
)
# 更新文件清单
with job_lock:
@@ -866,13 +970,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# AI智能体自动重试整理
if (
history
and settings.AI_AGENT_ENABLE
and settings.AI_AGENT_RETRY_TRANSFER
history
and settings.AI_AGENT_ENABLE
and settings.AI_AGENT_RETRY_TRANSFER
):
try:
from app.agent import agent_manager
# 使用 download_hash 或源文件父目录作为分组键,
# 同一批次如同一个种子的失败记录会被合并为一次agent调用
group_key = (
@@ -881,7 +983,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
else ""
)
asyncio.run_coroutine_threadsafe(
agent_manager.retry_failed_transfer(
self.retry_scheduler.schedule_retry(
history.id, group_key=group_key
),
global_vars.loop,
@@ -996,11 +1098,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
if self.jobview.is_torrent_success(t.download_hash):
processed_hashes.add(t.download_hash)
if self._can_delete_torrent(
t.download_hash, t.downloader, transfer_exclude_words
t.download_hash, t.downloader, transfer_exclude_words
):
# 移除种子及文件
if self.remove_torrents(
t.download_hash, downloader=t.downloader
t.download_hash, downloader=t.downloader
):
logger.info(
f"移动模式删除种子成功:{t.download_hash}"
@@ -1156,7 +1258,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.error(f"整理队列处理出现错误:{e} - {traceback.format_exc()}")
def __handle_transfer(
self, task: TransferTask, callback: Optional[Callable] = None
self, task: TransferTask, callback: Optional[Callable] = None
) -> Optional[Tuple[bool, str]]:
"""
处理整理任务
@@ -1223,13 +1325,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# AI智能体自动重试整理
if (
his
and settings.AI_AGENT_ENABLE
and settings.AI_AGENT_RETRY_TRANSFER
his
and settings.AI_AGENT_ENABLE
and settings.AI_AGENT_RETRY_TRANSFER
):
try:
from app.agent import agent_manager
# 使用 download_hash 或源文件父目录作为分组键
group_key = (
task.download_hash
@@ -1238,7 +1338,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
else ""
)
asyncio.run_coroutine_threadsafe(
agent_manager.retry_failed_transfer(
self.retry_scheduler.schedule_retry(
his.id, group_key=group_key
),
global_vars.loop,
@@ -1393,8 +1493,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 如果没有下载器监控的目录则不处理
if not any(
dir_info.monitor_type == "downloader" and dir_info.storage == "local"
for dir_info in download_dirs
dir_info.monitor_type == "downloader" and dir_info.storage == "local"
for dir_info in download_dirs
):
return True
@@ -1408,8 +1508,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
torrent
for torrent in torrents_list
if (h := torrent.hash) not in existing_hashes
# 排除多下载器返回的重复种子
and (h not in seen and (seen.add(h) or True))
# 排除多下载器返回的重复种子
and (h not in seen and (seen.add(h) or True))
]
else:
torrents = []
@@ -1480,7 +1580,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
fileitem=FileItem(
storage="local",
path=file_path.as_posix()
+ ("/" if file_path.is_dir() else ""),
+ ("/" if file_path.is_dir() else ""),
type="dir" if not file_path.is_file() else "file",
name=file_path.name,
size=file_path.stat().st_size,
@@ -1498,10 +1598,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return True
def __get_trans_fileitems(
self,
fileitem: FileItem,
predicate: Optional[Callable[[FileItem, bool], bool]],
verify_file_exists: bool = True,
self,
fileitem: FileItem,
predicate: Optional[Callable[[FileItem, bool], bool]],
verify_file_exists: bool = True,
) -> List[Tuple[FileItem, bool]]:
"""
获取待整理文件项列表
@@ -1541,7 +1641,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
def _apply_predicate(
file_item: FileItem, is_bluray_dir: bool
file_item: FileItem, is_bluray_dir: bool
) -> List[Tuple[FileItem, bool]]:
if predicate is None or predicate(file_item, is_bluray_dir):
return [(file_item, is_bluray_dir)]
@@ -1586,10 +1686,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
@staticmethod
def _resolve_download_history(
downloadhis: DownloadHistoryOper,
file_path: Path,
bluray_dir: bool = False,
download_hash: Optional[str] = None,
downloadhis: DownloadHistoryOper,
file_path: Path,
bluray_dir: bool = False,
download_hash: Optional[str] = None,
) -> Optional[DownloadHistory]:
"""
根据显式 hash、文件路径或种子根目录回查下载历史。
@@ -1624,26 +1724,26 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
def do_transfer(
self,
fileitem: FileItem,
meta: MetaBase = None,
mediainfo: MediaInfo = None,
target_directory: TransferDirectoryConf = None,
target_storage: Optional[str] = None,
target_path: Path = None,
transfer_type: Optional[str] = None,
scrape: Optional[bool] = None,
library_type_folder: Optional[bool] = None,
library_category_folder: Optional[bool] = None,
season: Optional[int] = None,
epformat: EpisodeFormat = None,
min_filesize: Optional[int] = 0,
downloader: Optional[str] = None,
download_hash: Optional[str] = None,
force: Optional[bool] = False,
background: Optional[bool] = True,
manual: Optional[bool] = False,
continue_callback: Callable = None,
self,
fileitem: FileItem,
meta: MetaBase = None,
mediainfo: MediaInfo = None,
target_directory: TransferDirectoryConf = None,
target_storage: Optional[str] = None,
target_path: Path = None,
transfer_type: Optional[str] = None,
scrape: Optional[bool] = None,
library_type_folder: Optional[bool] = None,
library_category_folder: Optional[bool] = None,
season: Optional[int] = None,
epformat: EpisodeFormat = None,
min_filesize: Optional[int] = 0,
downloader: Optional[str] = None,
download_hash: Optional[str] = None,
force: Optional[bool] = False,
background: Optional[bool] = True,
manual: Optional[bool] = False,
continue_callback: Callable = None,
) -> Tuple[bool, str]:
"""
执行一个复杂目录的整理操作
@@ -1690,7 +1790,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 汇总错误信息
err_msgs: List[str] = []
def _filter(file_item: FileItem, is_bluray_dir: bool) -> bool:
def _filter(item: FileItem, is_bluray_dir: bool) -> bool:
"""
过滤文件项
@@ -1699,30 +1799,30 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
if continue_callback and not continue_callback():
raise OperationInterrupted()
# 有集自定义格式,过滤文件
if formaterHandler and not formaterHandler.match(file_item.name):
if formaterHandler and not formaterHandler.match(item.name):
return False
# 过滤后缀和大小(蓝光目录、附加文件不过滤)
if (
not is_bluray_dir
and not self.__is_subtitle_file(file_item)
and not self.__is_audio_file(file_item)
not is_bluray_dir
and not self.__is_subtitle_file(item)
and not self.__is_audio_file(item)
):
if not self.__is_media_file(file_item):
if not self.__is_media_file(item):
return False
if not self.__is_allow_filesize(file_item, min_filesize):
if not self.__is_allow_filesize(item, min_filesize):
return False
# 回收站及隐藏的文件不处理
if (
file_item.path.find("/@Recycle/") != -1
or file_item.path.find("/#recycle/") != -1
or file_item.path.find("/.") != -1
or file_item.path.find("/@eaDir") != -1
item.path.find("/@Recycle/") != -1
or item.path.find("/#recycle/") != -1
or item.path.find("/.") != -1
or item.path.find("/@eaDir") != -1
):
logger.debug(f"{file_item.path} 是回收站或隐藏的文件")
logger.debug(f"{item.path} 是回收站或隐藏的文件")
return False
# 整理屏蔽词不处理
if self._is_blocked_by_exclude_words(
file_item.path, transfer_exclude_words
item.path, transfer_exclude_words
):
return False
return True
@@ -1929,11 +2029,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return all_success, error_msg
def remote_transfer(
self,
arg_str: str,
channel: MessageChannel,
userid: Union[str, int] = None,
source: Optional[str] = None,
self,
arg_str: str,
channel: MessageChannel,
userid: Union[str, int] = None,
source: Optional[str] = None,
):
"""
远程重新整理,参数 历史记录ID TMDBID|类型
@@ -1945,7 +2045,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
channel=channel,
source=source,
title="请输入正确的命令格式:/redo [id] 或 /redo [id] [tmdbid/豆瓣id]|[类型]"
"[id] 为整理记录编号",
"[id] 为整理记录编号",
userid=userid,
)
)
@@ -2005,7 +2105,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
@staticmethod
def build_failed_transfer_buttons(
history_id: Optional[int],
history_id: Optional[int],
) -> Optional[List[List[dict]]]:
"""
构建整理失败通知的操作按钮。
@@ -2029,7 +2129,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return self.__re_transfer(logid=history_id)
def __re_transfer(
self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None
self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None
) -> Tuple[bool, str]:
"""
根据历史记录,重新识别整理,只支持简单条件
@@ -2088,25 +2188,25 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return True, ""
def manual_transfer(
self,
fileitem: FileItem,
target_storage: Optional[str] = None,
target_path: Path = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
mtype: MediaType = None,
season: Optional[int] = None,
episode_group: Optional[str] = None,
transfer_type: Optional[str] = None,
epformat: EpisodeFormat = None,
min_filesize: Optional[int] = 0,
scrape: Optional[bool] = None,
library_type_folder: Optional[bool] = None,
library_category_folder: Optional[bool] = None,
force: Optional[bool] = False,
background: Optional[bool] = False,
downloader: Optional[str] = None,
download_hash: Optional[str] = None,
self,
fileitem: FileItem,
target_storage: Optional[str] = None,
target_path: Path = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
mtype: MediaType = None,
season: Optional[int] = None,
episode_group: Optional[str] = None,
transfer_type: Optional[str] = None,
epformat: EpisodeFormat = None,
min_filesize: Optional[int] = 0,
scrape: Optional[bool] = None,
library_type_folder: Optional[bool] = None,
library_category_folder: Optional[bool] = None,
force: Optional[bool] = False,
background: Optional[bool] = False,
downloader: Optional[str] = None,
download_hash: Optional[str] = None,
) -> Tuple[bool, Union[str, list]]:
"""
手动整理,支持复杂条件,带进度显示
@@ -2194,12 +2294,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return state, errmsg
def send_transfer_message(
self,
meta: MetaBase,
mediainfo: MediaInfo,
transferinfo: TransferInfo,
season_episode: Optional[str] = None,
username: Optional[str] = None,
self,
meta: MetaBase,
mediainfo: MediaInfo,
transferinfo: TransferInfo,
season_episode: Optional[str] = None,
username: Optional[str] = None,
):
"""
发送入库成功的消息
@@ -2237,7 +2337,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return False
def _can_delete_torrent(
self, download_hash: str, downloader: str, transfer_exclude_words
self, download_hash: str, downloader: str, transfer_exclude_words
) -> bool:
"""
检查是否可以删除种子文件
@@ -2270,11 +2370,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
file_path = save_path / file.name
# 如果存在未被屏蔽的媒体文件,则不删除种子
if (
file_path.suffix in self._allowed_exts
and not self._is_blocked_by_exclude_words(
file_path.as_posix(), transfer_exclude_words
)
and file_path.exists()
file_path.suffix in self._allowed_exts
and not self._is_blocked_by_exclude_words(
file_path.as_posix(), transfer_exclude_words
)
and file_path.exists()
):
return False

View File

@@ -506,7 +506,7 @@ class ConfigModel(BaseModel):
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"
# 思考模式/深度配置off/auto/minimal/low/medium/high/max/xhigh
LLM_THINKING_LEVEL: Optional[str] = 'off'
LLM_THINKING_LEVEL: Optional[str] = "off"
# LLM是否支持图片输入开启后消息图片会按多模态输入发送给模型
LLM_SUPPORT_IMAGE_INPUT: bool = True
# LLM是否支持音频输入输出开启后才会启用语音转写与语音回复

View File

@@ -1,6 +1,6 @@
from typing import Optional, Any
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class DownloadHistory(BaseModel):
@@ -97,3 +97,7 @@ class TransferHistory(BaseModel):
date: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class BatchTransferHistoryRedoRequest(BaseModel):
history_ids: list[int] = Field(default_factory=list)

View File

@@ -10,10 +10,10 @@ class AgentInitializer:
"""
AI智能体初始化器
"""
def __init__(self):
self._initialized = False
async def initialize(self) -> bool:
"""
初始化AI智能体管理器
@@ -22,16 +22,16 @@ class AgentInitializer:
if not settings.AI_AGENT_ENABLE:
logger.info("AI智能体功能未启用")
return True
await agent_manager.initialize()
self._initialized = True
logger.info("AI智能体管理器初始化成功")
return True
except Exception as e:
logger.error(f"AI智能体管理器初始化失败: {e}")
return False
async def cleanup(self) -> None:
"""
清理AI智能体管理器
@@ -39,11 +39,10 @@ class AgentInitializer:
try:
if not self._initialized:
return
await agent_manager.close()
self._initialized = False
logger.info("AI智能体管理器已关闭")
except Exception as e:
logger.error(f"关闭AI智能体管理器时发生错误: {e}")
@@ -60,7 +59,7 @@ def init_agent():
if not settings.AI_AGENT_ENABLE:
logger.info("AI智能体功能未启用")
return True
# 在新的事件循环中初始化AI智能体管理器
def run_init():
loop = asyncio.new_event_loop()
@@ -77,13 +76,13 @@ def init_agent():
return False
finally:
loop.close()
# 在后台线程中初始化
init_thread = threading.Thread(target=run_init, daemon=True)
init_thread.start()
return True
except Exception as e:
logger.error(f"初始化AI智能体时发生错误: {e}")
return False

View File

@@ -77,6 +77,21 @@ class TestAgentPromptStyle(unittest.TestCase):
self.assertIn("Total failed records: 1", message)
self.assertIn("history_id=7", message)
def test_render_batch_manual_transfer_redo_message(self):
message = prompt_manager.render_system_task_message(
"batch_manual_transfer_redo",
template_context={
"history_ids_csv": "7, 8",
"history_count": 2,
"records_context": "Record #7:\n- Source path: /downloads/a.mkv",
},
)
self.assertIn("[System Task - Batch Manual Transfer Re-Organize]", message)
self.assertIn("History IDs: 7, 8", message)
self.assertIn("Total records: 2", message)
self.assertIn("Record #7:", message)
def test_missing_system_task_template_context_raises_clear_error(self):
with self.assertRaises(PromptConfigError):
prompt_manager.render_system_task_message("transfer_failed_retry")

View File

@@ -1,6 +1,7 @@
import unittest
import sys
from types import ModuleType
from types import SimpleNamespace
from unittest.mock import patch
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
@@ -74,9 +75,33 @@ class TestTransferFailedRetryButtons(unittest.TestCase):
def test_transfer_ai_retry_callback_schedules_agent_takeover(self):
chain = MessageChain()
history = SimpleNamespace(
id=34,
status=False,
title="Test Show",
type="电视剧",
category=None,
year="2024",
seasons="S01",
episodes="E01",
src="/downloads/Test.Show.S01E01.mkv",
src_storage="local",
src_fileitem={"path": "/downloads/Test.Show.S01E01.mkv"},
dest=None,
dest_storage=None,
mode="copy",
tmdbid=123,
doubanid=None,
errmsg="未识别到媒体信息",
)
with patch.object(settings, "AI_AGENT_ENABLE", True):
with patch("app.chain.message.asyncio.run_coroutine_threadsafe") as run_task:
with patch(
"app.chain.message.TransferHistoryOper"
) as history_oper_cls, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe"
) as run_task:
history_oper_cls.return_value.get.return_value = history
with patch.object(chain, "post_message") as post_message:
chain._handle_callback(
text="CALLBACK:transfer_ai_retry_34",