From 460d71651230bb6d0bff6d233d7f3adcb3c55036 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 22:16:04 +0800 Subject: [PATCH] feat: add batch AI re-organize for transfer history and search result recommendation - Implement batch AI re-organize endpoint for transfer history with progress tracking - Add batch_manual_transfer_redo system task template and prompt generation - Refactor agent_manager to support generic background prompt execution - Add AIRecommendChain for search result recommendation using agent background prompt - Update search endpoints to use new AIRecommendChain and remove legacy code - Enhance test cases for batch manual transfer redo - Minor code cleanup and style fixes --- app/agent/__init__.py | 229 ++---------- app/agent/prompt/System Tasks.yaml | 42 +++ app/agent/tools/impl/query_schedulers.py | 3 +- app/agent/tools/impl/run_scheduler.py | 3 +- app/agent/tools/impl/transfer_file.py | 3 +- app/api/endpoints/history.py | 203 ++++++++++- app/api/endpoints/search.py | 26 +- app/chain/{ai_recommend.py => agent.py} | 141 +++----- app/chain/message.py | 369 +++++++++++--------- app/chain/skills.py | 8 +- app/chain/transfer.py | 354 ++++++++++++------- app/core/config.py | 2 +- app/schemas/history.py | 6 +- app/startup/agent_initializer.py | 21 +- tests/test_agent_prompt_style.py | 15 + tests/test_transfer_failed_retry_buttons.py | 27 +- 16 files changed, 821 insertions(+), 631 deletions(-) rename app/chain/{ai_recommend.py => agent.py} (67%) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 2677f632..4631900f 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -32,7 +32,6 @@ from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings -from app.db.transferhistory_oper import TransferHistoryOper from app.helper.llm import LLMHelper from app.log import logger from app.schemas import Notification, NotificationType @@ -731,21 +730,12 @@ class AgentManager: 同一会话的消息按顺序排队处理,不同会话之间互不影响。 """ - # 批量重试整理的等待时间(秒),同一批次内的失败记录会合并为一次agent调用 - RETRY_TRANSFER_DEBOUNCE_SECONDS = 300 - def __init__(self): self.active_agents: Dict[str, MoviePilotAgent] = {} # 每个会话的消息队列 self._session_queues: Dict[str, asyncio.Queue] = {} # 每个会话的worker任务 self._session_workers: Dict[str, asyncio.Task] = {} - # 重试整理的 debounce 缓冲区: group_key -> List[history_id] - self._retry_transfer_buffer: Dict[str, List[int]] = {} - # 重试整理的 debounce 定时器: group_key -> asyncio.TimerHandle - self._retry_transfer_timers: Dict[str, asyncio.TimerHandle] = {} - # 重试整理缓冲区锁 - self._retry_transfer_lock = asyncio.Lock() def get_session_status(self, session_id: str) -> dict[str, Any]: """获取会话当前模型与 token 使用状态。""" @@ -790,11 +780,6 @@ class AgentManager: 关闭管理器 """ await memory_manager.close() - # 取消所有重试整理的延迟定时器 - for timer in self._retry_transfer_timers.values(): - timer.cancel() - self._retry_transfer_timers.clear() - self._retry_transfer_buffer.clear() # 取消所有会话worker for task in self._session_workers.values(): task.cancel() @@ -995,67 +980,40 @@ class AgentManager: memory_manager.clear_memory(session_id, user_id) logger.info(f"会话 {session_id} 的记忆已清空") + @staticmethod + async def run_background_prompt( + message: str, + session_prefix: str = "__agent_background", + output_callback: Optional[Callable[[str], None]] = None, + suppress_user_reply: bool = False, + ) -> None: + """ + 以独立后台会话执行一段 prompt。 + """ + session_id = f"{session_prefix}_{uuid.uuid4().hex[:8]}__" + user_id = SYSTEM_INTERNAL_USER_ID + agent = MoviePilotAgent( + session_id=session_id, + user_id=user_id, + channel=None, + source=None, + username=settings.SUPERUSER, + ) + agent.output_callback = output_callback + agent.force_streaming = bool(output_callback) + agent.suppress_user_reply = suppress_user_reply + + try: + await agent.process(message) + finally: + await agent.cleanup() + memory_manager.clear_memory(session_id, user_id) + @staticmethod def _build_heartbeat_prompt() -> str: """使用程序内置 System Tasks 定义构建心跳任务提示词。""" return prompt_manager.render_system_task_message("heartbeat") - @staticmethod - def _build_retry_transfer_template_context( - history_ids: list[int], - ) -> tuple[str, dict[str, int | str]]: - """仅负责把失败重试任务的动态数据映射成模板变量。""" - is_batch = len(history_ids) > 1 - task_type = ( - "batch_transfer_failed_retry" if is_batch else "transfer_failed_retry" - ) - template_context: dict[str, int | str] = { - "history_ids_csv": ", ".join(str(item) for item in history_ids), - "history_count": len(history_ids), - } - if not is_batch: - template_context["history_id"] = history_ids[0] - return task_type, template_context - - @staticmethod - def _build_retry_transfer_prompt( - history_ids: list[int], - ) -> str: - """根据失败记录数量构建统一的重试整理后台任务提示词。""" - task_type, template_context = AgentManager._build_retry_transfer_template_context( - history_ids - ) - return prompt_manager.render_system_task_message( - task_type, - template_context=template_context, - ) - - @staticmethod - def _build_manual_redo_template_context(history) -> dict[str, int | str]: - """仅负责把整理历史对象映射成 System Tasks 需要的模板变量。""" - src_fileitem = history.src_fileitem or {} - source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else "" - source_path = source_path or history.src or "" - season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip() - # 这里故意只做数据整形,具体行为定义全部交给内置 System Tasks YAML。 - return { - "history_id": history.id, - "current_status": "success" if history.status else "failed", - "recognized_title": history.title or "unknown", - "media_type": history.type or "unknown", - "category": history.category or "unknown", - "year": history.year or "unknown", - "season_episode": season_episode or "unknown", - "source_path": source_path or "unknown", - "source_storage": history.src_storage or "local", - "destination_path": history.dest or "unknown", - "destination_storage": history.dest_storage or "unknown", - "transfer_mode": history.mode or "unknown", - "tmdbid": history.tmdbid or "none", - "doubanid": history.doubanid or "none", - "error_message": history.errmsg or "none", - } - async def heartbeat_check_jobs(self): """ 心跳唤醒:检查并执行待处理的定时任务(Jobs)。 @@ -1097,135 +1055,6 @@ class AgentManager: except Exception as e: logger.error(f"智能体心跳唤醒失败: {e}") - async def retry_failed_transfer(self, history_id: int, group_key: str = ""): - """ - 触发智能体重新整理失败的历史记录。 - 由文件整理模块在检测到整理失败后调用。 - 同一 group_key 的失败记录会在缓冲期内合并为一次agent调用,避免重复浪费token。 - :param history_id: 失败的整理历史记录ID - :param group_key: 分组键,相同key的记录会被合并处理(如download_hash、源目录等) - """ - if not group_key: - group_key = f"_default_{history_id}" - - async with self._retry_transfer_lock: - # 将 history_id 加入缓冲区 - if group_key not in self._retry_transfer_buffer: - self._retry_transfer_buffer[group_key] = [] - if history_id not in self._retry_transfer_buffer[group_key]: - self._retry_transfer_buffer[group_key].append(history_id) - logger.info( - f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 " - f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)" - ) - - # 取消该分组的旧定时器 - if group_key in self._retry_transfer_timers: - self._retry_transfer_timers[group_key].cancel() - - # 设置新的延迟定时器 - loop = asyncio.get_running_loop() - self._retry_transfer_timers[group_key] = loop.call_later( - self.RETRY_TRANSFER_DEBOUNCE_SECONDS, - lambda gk=group_key: asyncio.ensure_future( - self._flush_retry_transfer(gk) - ), - ) - - async def _flush_retry_transfer(self, group_key: str): - """ - 延迟定时器到期后,取出该分组的所有 history_id 并合并为一次agent调用。 - """ - async with self._retry_transfer_lock: - history_ids = self._retry_transfer_buffer.pop(group_key, []) - self._retry_transfer_timers.pop(group_key, None) - - if not history_ids: - return - - session_id = f"__agent_retry_transfer_batch_{uuid.uuid4().hex[:8]}__" - user_id = SYSTEM_INTERNAL_USER_ID - - ids_str = ", ".join(str(i) for i in history_ids) - logger.info( - f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})" - ) - retry_message = self._build_retry_transfer_prompt(history_ids) - - try: - await self.process_message( - session_id=session_id, - user_id=user_id, - message=retry_message, - channel=None, - source=None, - username=settings.SUPERUSER, - ) - - # 等待消息队列处理完成 - if session_id in self._session_queues: - await self._session_queues[session_id].join() - - # 等待worker结束 - if session_id in self._session_workers: - try: - await self._session_workers[session_id] - except asyncio.CancelledError: - pass - - logger.info( - f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})" - ) - - # 用完即弃,清理资源 - await self.clear_session(session_id, user_id) - - except Exception as e: - logger.error( - f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}" - ) - - @staticmethod - def _build_manual_redo_prompt(history) -> str: - """ - 构建手动 AI 整理提示词。 - """ - return prompt_manager.render_system_task_message( - "manual_transfer_redo", - template_context=AgentManager._build_manual_redo_template_context(history), - ) - - async def manual_redo_transfer( - self, - history_id: int, - output_callback: Optional[Callable[[str], None]] = None, - ) -> None: - """ - 手动触发单条历史记录的 AI 整理。 - """ - session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__" - user_id = SYSTEM_INTERNAL_USER_ID - agent = MoviePilotAgent( - session_id=session_id, - user_id=user_id, - channel=None, - source=None, - username=settings.SUPERUSER, - ) - agent.output_callback = output_callback - agent.force_streaming = True - agent.suppress_user_reply = True - - try: - history = TransferHistoryOper().get(history_id) - if not history: - raise ValueError(f"整理记录不存在: {history_id}") - - await agent.process(self._build_manual_redo_prompt(history)) - finally: - await agent.cleanup() - memory_manager.clear_memory(session_id, user_id) - # 全局智能体管理器实例 agent_manager = AgentManager() diff --git a/app/agent/prompt/System Tasks.yaml b/app/agent/prompt/System Tasks.yaml index f49464df..b68c3842 100644 --- a/app/agent/prompt/System Tasks.yaml +++ b/app/agent/prompt/System Tasks.yaml @@ -95,3 +95,45 @@ task_types: - "Do NOT reorganize blindly when media identity is uncertain." - "If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`." - "Keep the final response short and focused on outcome." + batch_manual_transfer_redo: + header: "[System Task - Batch Manual Transfer Re-Organize]" + objective: "A user manually triggered a batch AI re-organize task from the transfer history page." + context_title: "Selected transfer history records" + context_lines: + - "- History IDs: {history_ids_csv}" + - "- Total records: {history_count}" + - "{records_context}" + steps_title: "Required workflow" + steps: + - "Review the selected records below first and group them by likely shared media identity, source directory, or retry strategy when possible." + - "Use the provided record context as the primary source of truth. Call `query_transfer_history` only when you need extra confirmation." + - "For each group, decide whether the current recognition is trustworthy." + - "If multiple records clearly belong to the same movie or series, identify the media once with `recognize_media` or `search_media`, then reuse that result for the related records." + - "If a source file no longer exists or cannot be safely processed, skip that record and note the reason." + - "Before re-organizing a record, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file." + - "Then use `transfer_file` to organize the source path directly." + - "When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid or doubanid, and media_type." + - "If a record is already correct and no re-organize is needed, do not perform destructive actions; simply mark it as skipped." + - "Report only the aggregate outcome, including how many records succeeded, skipped, and failed." + task_rules: + - "Do NOT assume every selected record belongs to the same media." + - "When several records obviously share the same media identity, avoid repeated `recognize_media` or `search_media` calls." + - "Process every selected record exactly once." + - "Keep the final response short and focused on the aggregate outcome." + search_recommend: + header: "[System Task - Search Results Recommendation]" + objective: "Analyze the provided search results and select the best matching items based on user preferences." + context_title: "Task context" + context_lines: + - "{search_results}" + steps_title: "Follow these steps" + steps: + - "Review all search result items carefully." + - "Evaluate each item based on the user preference criteria." + - "Select the top items that best match the preferences." + - "Return ONLY a JSON array of item indices." + task_rules: + - "Return ONLY a JSON array of index numbers, e.g., [0, 3, 1]." + - "Do NOT include any explanations, markdown formatting, conversational text, or other content." + - "Do NOT call any tools. Simply analyze and return the JSON result directly." + - "Respond in JSON format only." diff --git a/app/agent/tools/impl/query_schedulers.py b/app/agent/tools/impl/query_schedulers.py index b254dd9b..5d872b7d 100644 --- a/app/agent/tools/impl/query_schedulers.py +++ b/app/agent/tools/impl/query_schedulers.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.log import logger -from app.scheduler import Scheduler class QuerySchedulersInput(BaseModel): @@ -27,6 +26,8 @@ class QuerySchedulersTool(MoviePilotTool): async def run(self, **kwargs) -> str: logger.info(f"执行工具: {self.name}") try: + from app.scheduler import Scheduler + scheduler = Scheduler() schedulers = scheduler.list() if schedulers: diff --git a/app/agent/tools/impl/run_scheduler.py b/app/agent/tools/impl/run_scheduler.py index 0eecc517..73fbb030 100644 --- a/app/agent/tools/impl/run_scheduler.py +++ b/app/agent/tools/impl/run_scheduler.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.log import logger -from app.scheduler import Scheduler class RunSchedulerInput(BaseModel): @@ -36,6 +35,8 @@ class RunSchedulerTool(MoviePilotTool): @staticmethod def _run_scheduler_sync(job_id: str) -> tuple[bool, str]: """同步触发定时服务,避免调度器扫描阻塞事件循环。""" + from app.scheduler import Scheduler + scheduler = Scheduler() for scheduler_item in scheduler.list(): if scheduler_item.id == job_id: diff --git a/app/agent/tools/impl/transfer_file.py b/app/agent/tools/impl/transfer_file.py index c782027c..cc5ddd2f 100644 --- a/app/agent/tools/impl/transfer_file.py +++ b/app/agent/tools/impl/transfer_file.py @@ -6,7 +6,6 @@ from typing import Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool -from app.chain.transfer import TransferChain from app.log import logger from app.schemas import FileItem, MediaType @@ -124,6 +123,8 @@ class TransferFileTool(MoviePilotTool): if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" + from app.chain.transfer import TransferChain + state, errormsg = TransferChain().manual_transfer( fileitem=fileitem, target_storage=target_storage, diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index fa048c78..f42c4916 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -1,14 +1,15 @@ import asyncio import time +from pathlib import Path from typing import List, Any, Optional import jieba from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from pathlib import Path from app import schemas +from app.agent import prompt_manager, agent_manager from app.chain.storage import StorageChain from app.core.config import settings, global_vars from app.core.event import eventmanager @@ -24,13 +25,99 @@ from app.schemas.types import EventType router = APIRouter() -def _start_ai_redo_task(history_id: int, progress_key: str): - from app.agent import agent_manager +def normalize_history_ids(history_ids: list[int]) -> list[int]: + """对输入的历史记录 ID 列表进行规范化处理,去除重复项并保持原有顺序。""" + normalized_ids: list[int] = [] + for history_id in history_ids: + if history_id not in normalized_ids: + normalized_ids.append(history_id) + return normalized_ids + +def build_manual_redo_template_context(history: TransferHistory) -> dict[str, int | str]: + """仅负责把整理历史对象映射成 System Tasks 需要的模板变量。""" + src_fileitem = history.src_fileitem or {} + source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else "" + source_path = source_path or history.src or "" + season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip() + return { + "history_id": history.id, + "current_status": "success" if history.status else "failed", + "recognized_title": history.title or "unknown", + "media_type": history.type or "unknown", + "category": history.category or "unknown", + "year": history.year or "unknown", + "season_episode": season_episode or "unknown", + "source_path": source_path or "unknown", + "source_storage": history.src_storage or "local", + "destination_path": history.dest or "unknown", + "destination_storage": history.dest_storage or "unknown", + "transfer_mode": history.mode or "unknown", + "tmdbid": history.tmdbid or "none", + "doubanid": history.doubanid or "none", + "error_message": history.errmsg or "none", + } + + +def format_manual_redo_record_context(history: Any) -> str: + """把单条整理记录格式化为批量任务可直接消费的上下文块。""" + context = build_manual_redo_template_context(history) + return "\n".join( + [ + f"Record #{context['history_id']}:", + f"- Current status: {context['current_status']}", + f"- Current recognized title: {context['recognized_title']}", + f"- Media type: {context['media_type']}", + f"- Category: {context['category']}", + f"- Year: {context['year']}", + f"- Season/Episode: {context['season_episode']}", + f"- Source path: {context['source_path']}", + f"- Source storage: {context['source_storage']}", + f"- Destination path: {context['destination_path']}", + f"- Destination storage: {context['destination_storage']}", + f"- Transfer mode: {context['transfer_mode']}", + f"- Current TMDB ID: {context['tmdbid']}", + f"- Current Douban ID: {context['doubanid']}", + f"- Error message: {context['error_message']}", + ] + ) + + +def build_manual_redo_prompt(history: Any) -> str: + """构建手动 AI 整理提示词。""" + return prompt_manager.render_system_task_message( + "manual_transfer_redo", + template_context=build_manual_redo_template_context(history), + ) + + +def build_batch_manual_redo_template_context( + histories: list[Any], +) -> dict[str, int | str]: + """仅负责把多条整理历史对象映射成批量 System Tasks 需要的模板变量。""" + return { + "history_ids_csv": ", ".join(str(history.id) for history in histories), + "history_count": len(histories), + "records_context": "\n\n".join( + format_manual_redo_record_context(history) for history in histories + ), + } + + +def build_batch_manual_redo_prompt(histories: list[Any]) -> str: + """构建批量手动 AI 整理提示词。""" + return prompt_manager.render_system_task_message( + "batch_manual_transfer_redo", + template_context=build_batch_manual_redo_template_context(histories), + ) + + +def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str): + """在后台线程中启动单条 AI 重新整理任务,并通过 ProgressHelper 实时更新进度。""" progress = ProgressHelper(progress_key) progress.start() progress.update( - text=f"智能助正在准备整理记录 #{history_id} ...", + text=f"智能助手正在准备整理记录 #{history_id} ...", data={"history_id": history_id, "success": True}, ) @@ -39,9 +126,11 @@ def _start_ai_redo_task(history_id: int, progress_key: str): async def runner(): try: - await agent_manager.manual_redo_transfer( - history_id=history_id, + await agent_manager.run_background_prompt( + message=prompt, + session_prefix=f"__agent_manual_redo_{history_id}", output_callback=update_output, + suppress_user_reply=True, ) progress.update( text="智能助手整理完成", @@ -63,6 +152,50 @@ def _start_ai_redo_task(history_id: int, progress_key: str): asyncio.run_coroutine_threadsafe(runner(), global_vars.loop) +def _start_batch_ai_redo_task( + history_ids: list[int], + prompt: str, + progress_key: str, +): + """在后台线程中启动批量 AI 重新整理任务,并通过 ProgressHelper 实时更新进度。""" + progress = ProgressHelper(progress_key) + progress.start() + progress.update( + text=f"智能助手正在准备批量整理 {len(history_ids)} 条记录 ...", + data={"history_ids": history_ids, "success": True}, + ) + + def update_output(text: str): + progress.update(text=text, data={"history_ids": history_ids}) + + async def runner(): + try: + await agent_manager.run_background_prompt( + message=prompt, + session_prefix="__agent_manual_redo_batch", + output_callback=update_output, + suppress_user_reply=True, + ) + progress.update( + text="智能助手批量整理完成", + data={"history_ids": history_ids, "success": True, "completed": True}, + ) + except Exception as e: + progress.update( + text=f"智能助手批量整理失败:{str(e)}", + data={ + "history_ids": history_ids, + "success": False, + "completed": True, + "error": str(e), + }, + ) + finally: + progress.end() + + asyncio.run_coroutine_threadsafe(runner(), global_vars.loop) + + @router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory]) async def download_history(page: Optional[int] = 1, count: Optional[int] = 30, @@ -159,9 +292,9 @@ def delete_transfer_history(history_in: schemas.TransferHistory, @router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response) def ai_redo_transfer_history( - history_id: int, - db: Session = Depends(get_db), - _: User = Depends(get_current_active_superuser), + history_id: int, + db: Session = Depends(get_db), + _: User = Depends(get_current_active_superuser), ) -> Any: """ 手动触发单条历史记录的 AI 重新整理,并返回进度键。 @@ -173,12 +306,62 @@ def ai_redo_transfer_history( if not history: return schemas.Response(success=False, message="整理记录不存在") + prompt = build_manual_redo_prompt(history) progress_key = f"ai_redo_transfer_{history_id}_{int(time.time() * 1000)}" - _start_ai_redo_task(history_id=history_id, progress_key=progress_key) + _start_ai_redo_task( + history_id=history_id, + prompt=prompt, + progress_key=progress_key, + ) return schemas.Response(success=True, data={"progress_key": progress_key}) +@router.post("/transfer/ai-redo", summary="智能助手批量重新整理", response_model=schemas.Response) +def batch_ai_redo_transfer_history( + payload: schemas.BatchTransferHistoryRedoRequest, + db: Session = Depends(get_db), + _: User = Depends(get_current_active_superuser), +) -> Any: + """ + 手动触发多条历史记录的 AI 批量重新整理,并返回进度键。 + """ + if not settings.AI_AGENT_ENABLE: + return schemas.Response(success=False, message="MoviePilot智能助手未启用") + + history_ids = normalize_history_ids(payload.history_ids) + if not history_ids: + return schemas.Response(success=False, message="未提供有效的整理记录") + + histories = [] + missing_ids = [] + for history_id in history_ids: + history = TransferHistory.get(db, history_id) + if not history: + missing_ids.append(history_id) + continue + histories.append(history) + + if missing_ids: + return schemas.Response( + success=False, + message="整理记录不存在: " + ", ".join(str(history_id) for history_id in missing_ids), + ) + + prompt = build_batch_manual_redo_prompt(histories) + progress_key = f"ai_redo_transfer_batch_{int(time.time() * 1000)}" + _start_batch_ai_redo_task( + history_ids=history_ids, + prompt=prompt, + progress_key=progress_key, + ) + + return schemas.Response( + success=True, + data={"progress_key": progress_key, "history_ids": history_ids}, + ) + + @router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response) async def empty_transfer_history(db: AsyncSession = Depends(get_async_db), _: User = Depends(get_current_active_superuser_async)) -> Any: diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index 42239760..2c33106b 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -5,9 +5,9 @@ from fastapi import APIRouter, Depends, Body, Request from fastapi.responses import StreamingResponse from app import schemas +from app.chain.agent import AIRecommendChain from app.chain.media import MediaChain from app.chain.search import SearchChain -from app.chain.ai_recommend import AIRecommendChain from app.core.config import settings from app.core.event import eventmanager from app.core.metainfo import MetaInfo @@ -73,7 +73,6 @@ async def search_by_id_stream(request: Request, """ 根据TMDBID/豆瓣ID渐进式搜索站点资源,返回格式为SSE """ - AIRecommendChain().cancel_ai_recommend() media_type = MediaType(mtype) if mtype else None media_season = int(season) if season else None @@ -206,8 +205,7 @@ async def search_by_id(mediaid: str, 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi: """ # 取消正在运行的AI推荐(会清除数据库缓存) - AIRecommendChain().cancel_ai_recommend() - + if mtype: media_type = MediaType(mtype) else: @@ -332,7 +330,6 @@ async def search_by_title_stream(request: Request, """ 根据名称渐进式模糊搜索站点资源,返回格式为SSE """ - AIRecommendChain().cancel_ai_recommend() event_source = SearchChain().async_search_by_title_stream( title=keyword, @@ -352,8 +349,7 @@ async def search_by_title(keyword: Optional[str] = None, 根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源 """ # 取消正在运行的AI推荐并清除数据库缓存 - AIRecommendChain().cancel_ai_recommend() - + torrents = await SearchChain().async_search_by_title( title=keyword, page=page, sites=_parse_site_list(sites), @@ -396,9 +392,9 @@ async def recommend_search_results( return schemas.Response(success=False, message="没有可用的搜索结果", data={ "status": "error" }) - + recommend_chain = AIRecommendChain() - + # 如果是强制模式,先取消并清除旧结果,然后直接启动新任务 if force: # 检查功能是否启用 @@ -413,7 +409,7 @@ async def recommend_search_results( return schemas.Response(success=True, data={ "status": "running" }) - + # 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测) if check_only: # 返回当前运行状态,不做任何任务启动或取消操作 @@ -423,14 +419,14 @@ async def recommend_search_results( error_msg = current_status.pop("error", "未知错误") return schemas.Response(success=False, message=error_msg, data=current_status) return schemas.Response(success=True, data=current_status) - + # 获取当前状态(会检测请求是否变化) status_data = recommend_chain.get_status(filtered_indices, len(results)) - + # 如果功能未启用,直接返回禁用状态 if status_data.get("status") == "disabled": return schemas.Response(success=True, data=status_data) - + # 如果是空闲状态,启动新任务 if status_data["status"] == "idle": recommend_chain.start_recommend_task(filtered_indices, len(results), results) @@ -438,11 +434,11 @@ async def recommend_search_results( return schemas.Response(success=True, data={ "status": "running" }) - + # 如果有错误,将错误信息放到message中 if status_data.get("status") == "error": error_msg = status_data.pop("error", "未知错误") return schemas.Response(success=False, message=error_msg, data=status_data) - + # 返回当前状态 return schemas.Response(success=True, data=status_data) diff --git a/app/chain/ai_recommend.py b/app/chain/agent.py similarity index 67% rename from app/chain/ai_recommend.py rename to app/chain/agent.py index 6b1d5fb1..3188c6d1 100644 --- a/app/chain/ai_recommend.py +++ b/app/chain/agent.py @@ -1,13 +1,13 @@ -import re -from typing import List, Optional, Dict, Any import asyncio import hashlib import json +import re +from typing import Any, Dict, List, Optional +from app.agent import agent_manager, prompt_manager from app.chain import ChainBase from app.core.config import settings from app.log import logger -from app.utils.common import log_execution_time from app.utils.singleton import Singleton from app.utils.string import StringUtils @@ -16,17 +16,16 @@ class AIRecommendChain(ChainBase, metaclass=Singleton): """ AI推荐处理链,单例运行 用于基于搜索结果的AI智能推荐 + 使用 agent_manager.run_background_prompt 统一后台任务机制 """ - # 缓存文件名 __ai_indices_cache_file = "__ai_recommend_indices__" - # AI推荐状态 _ai_recommend_running = False _ai_recommend_task: Optional[asyncio.Task] = None - _current_request_hash: Optional[str] = None # 当前请求的哈希值 - _ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表) - _ai_recommend_error: Optional[str] = None # AI推荐错误信息 + _current_request_hash: Optional[str] = None + _ai_recommend_result: Optional[List[int]] = None + _ai_recommend_error: Optional[str] = None @staticmethod def _calculate_request_hash( @@ -53,7 +52,6 @@ class AIRecommendChain(ChainBase, metaclass=Singleton): def _build_status(self) -> Dict[str, Any]: """ 构建AI推荐状态字典 - :return: 状态字典 """ if not self.is_enabled: return {"status": "disabled"} @@ -61,13 +59,11 @@ class AIRecommendChain(ChainBase, metaclass=Singleton): if self._ai_recommend_running: return {"status": "running"} - # 尝试从数据库加载缓存 if self._ai_recommend_result is None: cached_indices = self.load_cache(self.__ai_indices_cache_file) if cached_indices is not None: self._ai_recommend_result = cached_indices - # 只要有结果,始终返回completed状态和数据 if self._ai_recommend_result is not None: return {"status": "completed", "results": self._ai_recommend_result} @@ -89,76 +85,16 @@ class AIRecommendChain(ChainBase, metaclass=Singleton): 获取AI推荐状态并检查请求是否变化(用于首次请求或force模式) 如果请求变化(筛选条件变化),返回idle状态 """ - # 计算当前请求的hash request_hash = self._calculate_request_hash( filtered_indices, search_results_count ) - - # 检查请求是否变化 is_same_request = request_hash == self._current_request_hash - # 如果请求变化了(筛选条件改变),返回idle状态 if not is_same_request: return {"status": "idle"} if self.is_enabled else {"status": "disabled"} - # 请求未变化,返回当前实际状态 return self._build_status() - @log_execution_time(logger=logger) - async def async_ai_recommend(self, items: List[str], preference: str = None) -> str: - """ - AI推荐 - :param items: 候选资源列表(JSON字符串格式) - :param preference: 用户偏好(可选) - :return: AI返回的推荐结果 - """ - # 设置运行状态 - self._ai_recommend_running = True - try: - # 导入LLMHelper - from app.helper.llm import LLMHelper - - # 获取LLM实例 - llm = LLMHelper.get_llm() - - # 构建提示词 - user_preference = ( - preference - or settings.AI_RECOMMEND_USER_PREFERENCE - or "Prefer high-quality resources with more seeders" - ) - - # 添加指令 - instruction = """ -Task: Select the best matching items from the list based on user preferences. - -Each item contains: -- index: Item number -- title: Full torrent title -- size: File size -- seeders: Number of seeders - -Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text. -""" - message = ( - f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n" - + "\n".join(items) - ) - - # 调用LLM - response = await llm.ainvoke(message) - return response.content - - except ValueError as e: - logger.error(f"AI推荐配置错误: {e}") - raise - except Exception as e: - raise - finally: - # 清除运行状态 - self._ai_recommend_running = False - self._ai_recommend_task = None - def is_ai_recommend_running(self) -> bool: """ 检查AI推荐是否正在运行 @@ -186,44 +122,34 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do ) -> None: """ 启动AI推荐任务 + 使用 agent_manager.run_background_prompt 后台Agent机制执行推荐 :param filtered_indices: 筛选后的索引列表 :param search_results_count: 搜索结果总数 :param results: 搜索结果列表 """ - # 防护检查:确保AI推荐功能已启用 if not self.is_enabled: logger.warning("AI推荐功能未启用,跳过任务执行") return - # 计算新请求的哈希值 new_request_hash = self._calculate_request_hash( filtered_indices, search_results_count ) - # 如果请求变化了,取消旧任务 if new_request_hash != self._current_request_hash: self.cancel_ai_recommend() - - # 更新请求哈希值 self._current_request_hash = new_request_hash - - # 重置状态 self._ai_recommend_result = None self._ai_recommend_error = None - # 启动新任务 async def run_recommend(): - # 获取当前任务对象,用于在finally中比对 current_task = asyncio.current_task() try: self._ai_recommend_running = True - # 准备数据 items = [] valid_indices = [] max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50 - # 如果提供了筛选索引,先筛选结果;否则使用所有结果 if filtered_indices is not None and len(filtered_indices) > 0: results_to_process = [ results[i] @@ -259,27 +185,54 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do self._ai_recommend_error = "没有可用于AI推荐的资源" return - # 调用AI推荐 - ai_response = await self.async_ai_recommend(items) + user_preference = ( + settings.AI_RECOMMEND_USER_PREFERENCE + or "Prefer high-quality resources with more seeders" + ) + + search_results_text = "User Preference: {preference}\n\nCandidate Resources:\n{items}".format( + preference=user_preference, items="\n".join(items) + ) + + prompt = prompt_manager.render_system_task_message( + "search_recommend", + template_context={"search_results": search_results_text}, + ) + + full_output = [""] + + def on_output(text: str): + full_output[0] = text + + await agent_manager.run_background_prompt( + message=prompt, + session_prefix="__agent_search_recommend", + output_callback=on_output, + suppress_user_reply=True, + ) + + ai_response = full_output[0] + if not ai_response: + self._ai_recommend_error = "AI推荐未返回结果" + return - # 解析AI返回的索引 try: - # 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组) - json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL) + json_match = re.search(r"\[.*?]", ai_response, re.DOTALL) if not json_match: - raise ValueError(ai_response) - + raise ValueError(f"无法从响应中提取JSON数组: {ai_response}") + ai_indices = json.loads(json_match.group()) if not isinstance(ai_indices, list): raise ValueError(f"AI返回格式错误: {ai_response}") - # 映射回原始索引 if filtered_indices: original_indices = [ filtered_indices[valid_indices[i]] for i in ai_indices if i < len(valid_indices) - and 0 <= filtered_indices[valid_indices[i]] < len(results) + and 0 + <= filtered_indices[valid_indices[i]] + < len(results) ] else: original_indices = [ @@ -289,10 +242,7 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do and 0 <= valid_indices[i] < len(results) ] - # 只返回索引列表,不返回完整数据 self._ai_recommend_result = original_indices - - # 保存到数据库 self.save_cache(original_indices, self.__ai_indices_cache_file) logger.info(f"AI推荐完成: {len(original_indices)}项") @@ -308,11 +258,8 @@ Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do logger.error(f"AI推荐任务失败: {e}") self._ai_recommend_error = str(e) finally: - # 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态 - # 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置 if self._ai_recommend_task == current_task: self._ai_recommend_running = False self._ai_recommend_task = None - # 创建并启动任务 self._ai_recommend_task = asyncio.create_task(run_recommend()) diff --git a/app/chain/message.py b/app/chain/message.py index 306d988f..c109025c 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1,16 +1,15 @@ import asyncio +import base64 import mimetypes import re import time +import uuid from datetime import datetime, timedelta from pathlib import Path from typing import Any, Optional, Dict, Union, List from urllib.parse import unquote, urlparse -import uuid -import base64 - -from app.agent import agent_manager +from app.agent import agent_manager, prompt_manager from app.chain import ChainBase from app.chain.interaction import ( MediaInteractionChain, @@ -20,6 +19,8 @@ from app.chain.interaction import ( from app.chain.skills import SkillsChain, skills_interaction_manager from app.chain.transfer import TransferChain from app.core.config import settings, global_vars +from app.db.models import TransferHistory +from app.db.transferhistory_oper import TransferHistoryOper from app.helper.llm import LLMHelper from app.helper.voice import VoiceHelper from app.log import logger @@ -92,17 +93,17 @@ class MessageChain(ChainBase): ) def handle_message( - self, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - text: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - images: Optional[List[CommingMessage.MessageImage]] = None, - audio_refs: Optional[List[str]] = None, - files: Optional[List[CommingMessage.MessageAttachment]] = None, + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + images: Optional[List[CommingMessage.MessageImage]] = None, + audio_refs: Optional[List[str]] = None, + files: Optional[List[CommingMessage.MessageAttachment]] = None, ) -> None: """ 识别消息内容,执行操作 @@ -171,21 +172,21 @@ class MessageChain(ChainBase): if skills_interaction_manager.get_by_user(userid): if SkillsChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return if media_interaction_manager.get_by_user(userid): if MediaInteractionChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return @@ -202,8 +203,8 @@ class MessageChain(ChainBase): return if ( - settings.AI_AGENT_ENABLE - and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input) + settings.AI_AGENT_ENABLE + and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input) ): self._handle_ai_message( text=text, @@ -217,11 +218,11 @@ class MessageChain(ChainBase): return if MediaInteractionChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return @@ -236,14 +237,14 @@ class MessageChain(ChainBase): ) def _handle_callback( - self, - text: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, + self, + text: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, ) -> None: """ 处理按钮回调 @@ -254,44 +255,44 @@ class MessageChain(ChainBase): logger.info(f"处理按钮回调:{callback_data}") if self._handle_transfer_callback( - callback_data=callback_data, - channel=channel, - source=source, - userid=userid, - username=username, + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, ): return if SkillsChain().handle_callback_interaction( - callback_data=callback_data, - channel=channel, - source=source, - userid=userid, - username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ): return if MediaInteractionChain().handle_callback_interaction( - callback_data=callback_data, - channel=channel, - source=source, - userid=userid, - username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ): return if self._handle_agent_choice_callback( - callback_data=callback_data, - channel=channel, - source=source, - userid=userid, - username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ): return @@ -327,14 +328,14 @@ class MessageChain(ChainBase): @staticmethod def _parse_transfer_callback( - callback_data: str, + callback_data: str, ) -> Optional[tuple[str, int]]: """ 解析整理失败通知按钮回调。 """ for prefix, action in ( - ("transfer_retry_", "retry"), - ("transfer_ai_retry_", "ai_retry"), + ("transfer_retry_", "retry"), + ("transfer_ai_retry_", "ai_retry"), ): if callback_data.startswith(prefix): history_id = callback_data.replace(prefix, "", 1) @@ -343,12 +344,12 @@ class MessageChain(ChainBase): return None def _handle_transfer_callback( - self, - callback_data: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, ) -> bool: """ 处理整理失败通知中的重试类按钮。 @@ -378,7 +379,7 @@ class MessageChain(ChainBase): @staticmethod def _parse_agent_choice_callback( - callback_data: str, + callback_data: str, ) -> Optional[tuple[str, int]]: """ 解析 Agent 按钮选择回调。 @@ -401,14 +402,14 @@ class MessageChain(ChainBase): return request_id, int(option_index) def _handle_agent_choice_callback( - self, - callback_data: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, ) -> bool: """ 将 Agent 按钮选择回传为同一会话中的下一条用户消息。 @@ -465,14 +466,14 @@ class MessageChain(ChainBase): return True def _update_interaction_message_feedback( - self, - channel: MessageChannel, - source: str, - original_message_id: Optional[Union[str, int]], - original_chat_id: Optional[str], - prompt: str, - selected_label: str, - title: Optional[str] = None, + self, + channel: MessageChannel, + source: str, + original_message_id: Optional[Union[str, int]], + original_chat_id: Optional[str], + prompt: str, + selected_label: str, + title: Optional[str] = None, ) -> None: """ 在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。 @@ -494,12 +495,12 @@ class MessageChain(ChainBase): ) def _retry_transfer_history( - self, - history_id: int, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, + self, + history_id: int, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, ) -> None: """ 立即重新整理一条失败的整理记录。 @@ -541,16 +542,46 @@ class MessageChain(ChainBase): ) def _take_over_transfer_history_by_ai( - self, - history_id: int, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, + self, + history_id: int, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, ) -> None: """ 由智能助手接管一条失败的整理记录。 """ + + def __build_manual_redo_prompt(his: TransferHistory) -> str: + """构建手动 AI 整理提示词。""" + + src_fileitem = his.src_fileitem or {} + source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else "" + source_path = source_path or his.src or "" + season_episode = f"{his.seasons or ''}{his.episodes or ''}".strip() + template_context = { + "his_id": his.id, + "current_status": "success" if his.status else "failed", + "recognized_title": his.title or "unknown", + "media_type": his.type or "unknown", + "category": his.category or "unknown", + "year": his.year or "unknown", + "season_episode": season_episode or "unknown", + "source_path": source_path or "unknown", + "source_storage": his.src_storage or "local", + "destination_path": his.dest or "unknown", + "destination_storage": his.dest_storage or "unknown", + "transfer_mode": his.mode or "unknown", + "tmdbid": his.tmdbid or "none", + "doubanid": his.doubanid or "none", + "error_message": his.errmsg or "none", + } + return prompt_manager.render_system_task_message( + "manual_transfer_redo", + template_context=template_context, + ) + if not settings.AI_AGENT_ENABLE: self.post_message( Notification( @@ -563,6 +594,23 @@ class MessageChain(ChainBase): ) return + history = TransferHistoryOper().get(history_id) + if not history: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="重新整理失败", + text=f"整理记录 #{history_id} 不存在", + link=settings.MP_DOMAIN("#/history"), + ) + ) + return + + redo_prompt = __build_manual_redo_prompt(history) + self.post_message( Notification( channel=channel, @@ -583,9 +631,11 @@ class MessageChain(ChainBase): final_output = text_output or "" try: - await agent_manager.manual_redo_transfer( - history_id=history_id, + await agent_manager.run_background_prompt( + message=redo_prompt, + session_prefix=f"__agent_manual_redo_{history_id}", output_callback=_capture_output, + suppress_user_reply=True, ) await self.async_post_message( Notification( @@ -595,7 +645,7 @@ class MessageChain(ChainBase): username=username, title="智能助手整理完成", text=final_output.strip() - or f"整理记录 #{history_id} 已由智能助手处理完成。", + or f"整理记录 #{history_id} 已由智能助手处理完成。", link=settings.MP_DOMAIN("#/history"), ) ) @@ -650,12 +700,12 @@ class MessageChain(ChainBase): self._user_sessions[userid] = (session_id, datetime.now()) def _record_user_message( - self, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - text: str, + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, ) -> None: """ 保存一条用户消息到消息历史与数据库。 @@ -690,10 +740,10 @@ class MessageChain(ChainBase): return False def remote_clear_session( - self, - channel: MessageChannel, - userid: Union[str, int], - source: Optional[str] = None, + self, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, ): """ 清除用户会话(远程命令接口) @@ -735,10 +785,10 @@ class MessageChain(ChainBase): ) def remote_stop_agent( - self, - channel: MessageChannel, - userid: Union[str, int], - source: Optional[str] = None, + self, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, ): """ 应急停止当前正在执行的Agent推理(远程命令接口)。 @@ -805,7 +855,7 @@ class MessageChain(ChainBase): f"({context_ratio * 100:.2f}%)" if context_ratio is not None else f"{cls._format_token_count(last_input_tokens)} / " - f"{cls._format_token_count(context_window_tokens)}" + f"{cls._format_token_count(context_window_tokens)}" ) else: context_usage_text = "暂无模型调用数据" @@ -825,10 +875,10 @@ class MessageChain(ChainBase): return "\n".join(lines) def remote_session_status( - self, - channel: MessageChannel, - userid: Union[str, int], - source: Optional[str] = None, + self, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, ): """查询当前用户的智能体会话状态。""" session_info = self._user_sessions.get(userid) @@ -856,15 +906,15 @@ class MessageChain(ChainBase): ) def _handle_ai_message( - self, - text: str, - channel: MessageChannel, - source: str, - userid: Union[str, int], - username: str, - images: Optional[List[CommingMessage.MessageImage]] = None, - files: Optional[List[CommingMessage.MessageAttachment]] = None, - session_id: Optional[str] = None, + self, + text: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + images: Optional[List[CommingMessage.MessageImage]] = None, + files: Optional[List[CommingMessage.MessageAttachment]] = None, + session_id: Optional[str] = None, ) -> None: """ 处理AI智能体消息 @@ -928,10 +978,10 @@ class MessageChain(ChainBase): elif images: image_attachments = self._build_image_attachments(images) if ( - original_images - and not image_attachments - and not user_message - and not files + original_images + and not image_attachments + and not user_message + and not files ): self.post_message( Notification( @@ -986,7 +1036,7 @@ class MessageChain(ChainBase): ) def _transcribe_audio_refs( - self, audio_refs: List[str], channel: MessageChannel, source: str + self, audio_refs: List[str], channel: MessageChannel, source: str ) -> Optional[str]: """ 下载并识别语音消息,仅处理当前已接入的渠道。 @@ -1119,10 +1169,10 @@ class MessageChain(ChainBase): return default def _download_attachments_to_data_urls( - self, - attachments: List[CommingMessage.MessageImage], - channel: MessageChannel, - source: str, + self, + attachments: List[CommingMessage.MessageImage], + channel: MessageChannel, + source: str, ) -> Optional[List[str]]: """ 下载可直接提供给 LLM 的附件内容,并统一转换为 data URL。 @@ -1147,7 +1197,7 @@ class MessageChain(ChainBase): if base64_data: data_urls.append(f"data:image/jpeg;base64,{base64_data}") elif attachment_ref.startswith( - "wxwork://media_id/" + "wxwork://media_id/" ) or attachment_ref.startswith( "wxbot://image/" ): @@ -1208,7 +1258,7 @@ class MessageChain(ChainBase): return data_urls if data_urls else None def _build_image_attachments( - self, images: List[CommingMessage.MessageImage] + self, images: List[CommingMessage.MessageImage] ) -> List[CommingMessage.MessageAttachment]: """ 将图片引用转换为附件描述,以便按文件方式交给 Agent 处理。 @@ -1235,11 +1285,11 @@ class MessageChain(ChainBase): return attachments def _prepare_agent_files( - self, - session_id: str, - files: Optional[List[CommingMessage.MessageAttachment]], - channel: MessageChannel, - source: str, + self, + session_id: str, + files: Optional[List[CommingMessage.MessageAttachment]], + channel: MessageChannel, + source: str, ) -> Optional[List[dict]]: """ 下载用户上传的附件,落盘到临时目录,并生成 Agent 可消费的文件描述。 @@ -1286,7 +1336,7 @@ class MessageChain(ChainBase): return prepared_files or None def _download_message_file_bytes( - self, file_ref: str, channel: MessageChannel, source: str + self, file_ref: str, channel: MessageChannel, source: str ) -> Optional[bytes]: """ 下载消息附件的原始字节内容。 @@ -1359,11 +1409,11 @@ class MessageChain(ChainBase): return None def _save_agent_attachment( - self, - session_id: str, - filename: Optional[str], - content: bytes, - mime_type: Optional[str] = None, + self, + session_id: str, + filename: Optional[str], + content: bytes, + mime_type: Optional[str] = None, ) -> Path: """ 将用户上传文件写入临时目录,并返回本地路径。 @@ -1379,7 +1429,7 @@ class MessageChain(ChainBase): @staticmethod def _sanitize_attachment_name( - filename: Optional[str], mime_type: Optional[str] = None + filename: Optional[str], mime_type: Optional[str] = None ) -> str: """ 规范化附件文件名,避免路径穿越和非法字符。 @@ -1449,5 +1499,6 @@ class MessageChain(ChainBase): return None try: return base64.b64decode(payload) - except Exception: + except Exception as e: + logger.error(e) return None diff --git a/app/chain/skills.py b/app/chain/skills.py index 7bafae2d..430c6108 100644 --- a/app/chain/skills.py +++ b/app/chain/skills.py @@ -724,8 +724,7 @@ class SkillsChain(ChainBase): """ if request.view == "installed": title, text, buttons = self._build_installed_view( - request=request, - force_market_refresh=force_market_refresh, + request=request ) elif request.view == "market": title, text, buttons = self._build_market_view( @@ -735,7 +734,6 @@ class SkillsChain(ChainBase): elif request.view == "sources": title, text, buttons = self._build_sources_view( request=request, - force_market_refresh=force_market_refresh, ) else: title, text, buttons = self._build_root_view( @@ -808,8 +806,7 @@ class SkillsChain(ChainBase): def _build_installed_view( self, - request: PendingSkillsInteraction, - force_market_refresh: bool = False, # noqa: ARG002 + request: PendingSkillsInteraction ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建已安装技能视图,列出来源和可删除状态。 @@ -971,7 +968,6 @@ class SkillsChain(ChainBase): def _build_sources_view( self, request: PendingSkillsInteraction, - force_market_refresh: bool = False, # noqa: ARG002 ) -> Tuple[str, str, Optional[List[List[dict]]]]: """ 构建技能源管理视图,提供自定义 GitHub 源的增删入口。 diff --git a/app/chain/transfer.py b/app/chain/transfer.py index 855695ab..22f509b8 100755 --- a/app/chain/transfer.py +++ b/app/chain/transfer.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import List, Optional, Tuple, Union, Dict, Callable from app import schemas +from app.agent import prompt_manager, agent_manager from app.chain import ChainBase from app.chain.media import MediaChain from app.chain.storage import StorageChain @@ -162,10 +163,10 @@ class JobManager: else: # 不重复添加任务 if any( - [ - t.fileitem == task.fileitem - for t in self._job_view[__mediaid__].tasks - ] + [ + t.fileitem == task.fileitem + for t in self._job_view[__mediaid__].tasks + ] ): logger.debug(f"任务 {task.fileitem.name} 已存在,跳过重复添加") return False @@ -301,7 +302,7 @@ class JobManager: return task def __remove_task_with_job_id( - self, fileitem: FileItem + self, fileitem: FileItem ) -> Tuple[Optional[TransferJobTask], Optional[Tuple]]: """ 根据文件项移除任务,并返回任务所在的作业ID @@ -462,10 +463,10 @@ class JobManager: """ with job_lock: if any( - task.state not in {"completed", "failed"} - for job in self._job_view.values() - for task in job.tasks - if task.download_hash == download_hash + task.state not in {"completed", "failed"} + for job in self._job_view.values() + for task in job.tasks + if task.download_hash == download_hash ): return False return True @@ -476,19 +477,19 @@ class JobManager: """ with job_lock: if any( - task.state != "completed" - for job in self._job_view.values() - for task in job.tasks - if task.download_hash == download_hash + task.state != "completed" + for job in self._job_view.values() + for task in job.tasks + if task.download_hash == download_hash ): return False return True def has_tasks( - self, - meta: MetaBase, - mediainfo: Optional[MediaInfo] = None, - season: Optional[int] = None, + self, + meta: MetaBase, + mediainfo: Optional[MediaInfo] = None, + season: Optional[int] = None, ) -> bool: """ 判断作业是否还有任务正在处理 @@ -501,12 +502,12 @@ class JobManager: __metaid__ = self.__get_meta_id(meta=meta, season=season) return ( - __metaid__ in self._job_view - and len(self._job_view[__metaid__].tasks) > 0 + __metaid__ in self._job_view + and len(self._job_view[__metaid__].tasks) > 0 ) def success_tasks( - self, media: MediaInfo, season: Optional[int] = None + self, media: MediaInfo, season: Optional[int] = None ) -> List[TransferJobTask]: """ 获取作业中所有成功的任务 @@ -522,7 +523,7 @@ class JobManager: ] def all_tasks( - self, media: MediaInfo, season: Optional[int] = None + self, media: MediaInfo, season: Optional[int] = None ) -> List[TransferJobTask]: """ 获取作业中全部任务 @@ -586,7 +587,7 @@ class JobManager: return list(self._job_view.values()) def season_episodes( - self, media: MediaInfo, season: Optional[int] = None + self, media: MediaInfo, season: Optional[int] = None ) -> List[int]: """ 获取作业的季集清单 @@ -596,6 +597,107 @@ class JobManager: return self._season_episodes.get(__mediaid__) or [] +class FailedRetryScheduler: + """ + 负责失败整理记录的 debounce 聚合与 AI 重试调度。 + """ + + RETRY_TRANSFER_DEBOUNCE_SECONDS = 300 + + def __init__(self): + super().__init__() + self._retry_transfer_buffer: dict[str, list[int]] = {} + self._retry_transfer_timers: dict[str, asyncio.TimerHandle] = {} + self._retry_transfer_lock = asyncio.Lock() + + async def close(self): + async with self._retry_transfer_lock: + timers = list(self._retry_transfer_timers.values()) + self._retry_transfer_timers.clear() + self._retry_transfer_buffer.clear() + + for timer in timers: + timer.cancel() + + @staticmethod + def _build_retry_transfer_template_context( + history_ids: list[int], + ) -> tuple[str, dict[str, int | str]]: + """仅负责把失败重试任务的动态数据映射成模板变量。""" + is_batch = len(history_ids) > 1 + task_type = "batch_transfer_failed_retry" if is_batch else "transfer_failed_retry" + template_context: dict[str, int | str] = { + "history_ids_csv": ", ".join(str(item) for item in history_ids), + "history_count": len(history_ids), + } + if not is_batch: + template_context["history_id"] = history_ids[0] + return task_type, template_context + + def _build_retry_transfer_prompt(self, history_ids: list[int]) -> str: + """根据失败记录数量构建统一的重试整理后台任务提示词。""" + task_type, template_context = self._build_retry_transfer_template_context(history_ids) + return prompt_manager.render_system_task_message( + task_type, + template_context=template_context, + ) + + async def schedule_retry(self, history_id: int, group_key: str = ""): + """ + 同一 group_key 的失败记录会在缓冲期内合并为一次 agent 调用。 + """ + if not group_key: + group_key = f"_default_{history_id}" + + async with self._retry_transfer_lock: + if group_key not in self._retry_transfer_buffer: + self._retry_transfer_buffer[group_key] = [] + if history_id not in self._retry_transfer_buffer[group_key]: + self._retry_transfer_buffer[group_key].append(history_id) + logger.info( + f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 " + f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)" + ) + + if group_key in self._retry_transfer_timers: + self._retry_transfer_timers[group_key].cancel() + + loop = asyncio.get_running_loop() + self._retry_transfer_timers[group_key] = loop.call_later( + self.RETRY_TRANSFER_DEBOUNCE_SECONDS, + lambda gk=group_key: asyncio.create_task(self._flush_retry_transfer(gk)), + ) + + async def _flush_retry_transfer(self, group_key: str): + """ + 延迟定时器到期后,取出该分组的所有 history_id 并合并为一次 agent 调用。 + """ + async with self._retry_transfer_lock: + history_ids = self._retry_transfer_buffer.pop(group_key, []) + self._retry_transfer_timers.pop(group_key, None) + + if not history_ids: + return + + ids_str = ", ".join(str(item) for item in history_ids) + logger.info( + f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})" + ) + + try: + await agent_manager.run_background_prompt( + message=self._build_retry_transfer_prompt(history_ids), + session_prefix="__agent_retry_transfer_batch", + ) + logger.info( + f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})" + ) + except Exception as err: + logger.error( + f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {err}" + ) + + class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): """ 文件整理处理链 @@ -623,6 +725,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): self._transfer_interval = 15 # 事件管理器 self.jobview = JobManager() + # Agent重试管理器 + self.retry_scheduler = FailedRetryScheduler() # 转移成功的文件清单 self._success_target_files: Dict[str, List[str]] = {} # 整理进度进度 @@ -713,7 +817,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): ) def __default_callback( - self, task: TransferTask, transferinfo: TransferInfo, / + self, task: TransferTask, transferinfo: TransferInfo, / ) -> Tuple[bool, str]: """ 整理完成后处理 @@ -730,12 +834,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): """ # 更新文件数量 transferinfo.file_count = ( - self.jobview.count(task.mediainfo, task.meta.begin_season) or 1 + self.jobview.count(task.mediainfo, task.meta.begin_season) or 1 ) # 更新文件大小 transferinfo.total_size = ( - self.jobview.size(task.mediainfo, task.meta.begin_season) - or task.fileitem.size + self.jobview.size(task.mediainfo, task.meta.begin_season) + or task.fileitem.size ) # 更新文件清单 with job_lock: @@ -866,13 +970,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # AI智能体自动重试整理 if ( - history - and settings.AI_AGENT_ENABLE - and settings.AI_AGENT_RETRY_TRANSFER + history + and settings.AI_AGENT_ENABLE + and settings.AI_AGENT_RETRY_TRANSFER ): try: - from app.agent import agent_manager - # 使用 download_hash 或源文件父目录作为分组键, # 同一批次(如同一个种子)的失败记录会被合并为一次agent调用 group_key = ( @@ -881,7 +983,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): else "" ) asyncio.run_coroutine_threadsafe( - agent_manager.retry_failed_transfer( + self.retry_scheduler.schedule_retry( history.id, group_key=group_key ), global_vars.loop, @@ -996,11 +1098,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): if self.jobview.is_torrent_success(t.download_hash): processed_hashes.add(t.download_hash) if self._can_delete_torrent( - t.download_hash, t.downloader, transfer_exclude_words + t.download_hash, t.downloader, transfer_exclude_words ): # 移除种子及文件 if self.remove_torrents( - t.download_hash, downloader=t.downloader + t.download_hash, downloader=t.downloader ): logger.info( f"移动模式删除种子成功:{t.download_hash}" @@ -1156,7 +1258,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): logger.error(f"整理队列处理出现错误:{e} - {traceback.format_exc()}") def __handle_transfer( - self, task: TransferTask, callback: Optional[Callable] = None + self, task: TransferTask, callback: Optional[Callable] = None ) -> Optional[Tuple[bool, str]]: """ 处理整理任务 @@ -1223,13 +1325,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # AI智能体自动重试整理 if ( - his - and settings.AI_AGENT_ENABLE - and settings.AI_AGENT_RETRY_TRANSFER + his + and settings.AI_AGENT_ENABLE + and settings.AI_AGENT_RETRY_TRANSFER ): try: - from app.agent import agent_manager - # 使用 download_hash 或源文件父目录作为分组键 group_key = ( task.download_hash @@ -1238,7 +1338,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): else "" ) asyncio.run_coroutine_threadsafe( - agent_manager.retry_failed_transfer( + self.retry_scheduler.schedule_retry( his.id, group_key=group_key ), global_vars.loop, @@ -1393,8 +1493,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 如果没有下载器监控的目录则不处理 if not any( - dir_info.monitor_type == "downloader" and dir_info.storage == "local" - for dir_info in download_dirs + dir_info.monitor_type == "downloader" and dir_info.storage == "local" + for dir_info in download_dirs ): return True @@ -1408,8 +1508,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): torrent for torrent in torrents_list if (h := torrent.hash) not in existing_hashes - # 排除多下载器返回的重复种子 - and (h not in seen and (seen.add(h) or True)) + # 排除多下载器返回的重复种子 + and (h not in seen and (seen.add(h) or True)) ] else: torrents = [] @@ -1480,7 +1580,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): fileitem=FileItem( storage="local", path=file_path.as_posix() - + ("/" if file_path.is_dir() else ""), + + ("/" if file_path.is_dir() else ""), type="dir" if not file_path.is_file() else "file", name=file_path.name, size=file_path.stat().st_size, @@ -1498,10 +1598,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return True def __get_trans_fileitems( - self, - fileitem: FileItem, - predicate: Optional[Callable[[FileItem, bool], bool]], - verify_file_exists: bool = True, + self, + fileitem: FileItem, + predicate: Optional[Callable[[FileItem, bool], bool]], + verify_file_exists: bool = True, ) -> List[Tuple[FileItem, bool]]: """ 获取待整理文件项列表 @@ -1541,7 +1641,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None def _apply_predicate( - file_item: FileItem, is_bluray_dir: bool + file_item: FileItem, is_bluray_dir: bool ) -> List[Tuple[FileItem, bool]]: if predicate is None or predicate(file_item, is_bluray_dir): return [(file_item, is_bluray_dir)] @@ -1586,10 +1686,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): @staticmethod def _resolve_download_history( - downloadhis: DownloadHistoryOper, - file_path: Path, - bluray_dir: bool = False, - download_hash: Optional[str] = None, + downloadhis: DownloadHistoryOper, + file_path: Path, + bluray_dir: bool = False, + download_hash: Optional[str] = None, ) -> Optional[DownloadHistory]: """ 根据显式 hash、文件路径或种子根目录回查下载历史。 @@ -1624,26 +1724,26 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return None def do_transfer( - self, - fileitem: FileItem, - meta: MetaBase = None, - mediainfo: MediaInfo = None, - target_directory: TransferDirectoryConf = None, - target_storage: Optional[str] = None, - target_path: Path = None, - transfer_type: Optional[str] = None, - scrape: Optional[bool] = None, - library_type_folder: Optional[bool] = None, - library_category_folder: Optional[bool] = None, - season: Optional[int] = None, - epformat: EpisodeFormat = None, - min_filesize: Optional[int] = 0, - downloader: Optional[str] = None, - download_hash: Optional[str] = None, - force: Optional[bool] = False, - background: Optional[bool] = True, - manual: Optional[bool] = False, - continue_callback: Callable = None, + self, + fileitem: FileItem, + meta: MetaBase = None, + mediainfo: MediaInfo = None, + target_directory: TransferDirectoryConf = None, + target_storage: Optional[str] = None, + target_path: Path = None, + transfer_type: Optional[str] = None, + scrape: Optional[bool] = None, + library_type_folder: Optional[bool] = None, + library_category_folder: Optional[bool] = None, + season: Optional[int] = None, + epformat: EpisodeFormat = None, + min_filesize: Optional[int] = 0, + downloader: Optional[str] = None, + download_hash: Optional[str] = None, + force: Optional[bool] = False, + background: Optional[bool] = True, + manual: Optional[bool] = False, + continue_callback: Callable = None, ) -> Tuple[bool, str]: """ 执行一个复杂目录的整理操作 @@ -1690,7 +1790,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): # 汇总错误信息 err_msgs: List[str] = [] - def _filter(file_item: FileItem, is_bluray_dir: bool) -> bool: + def _filter(item: FileItem, is_bluray_dir: bool) -> bool: """ 过滤文件项 @@ -1699,30 +1799,30 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): if continue_callback and not continue_callback(): raise OperationInterrupted() # 有集自定义格式,过滤文件 - if formaterHandler and not formaterHandler.match(file_item.name): + if formaterHandler and not formaterHandler.match(item.name): return False # 过滤后缀和大小(蓝光目录、附加文件不过滤) if ( - not is_bluray_dir - and not self.__is_subtitle_file(file_item) - and not self.__is_audio_file(file_item) + not is_bluray_dir + and not self.__is_subtitle_file(item) + and not self.__is_audio_file(item) ): - if not self.__is_media_file(file_item): + if not self.__is_media_file(item): return False - if not self.__is_allow_filesize(file_item, min_filesize): + if not self.__is_allow_filesize(item, min_filesize): return False # 回收站及隐藏的文件不处理 if ( - file_item.path.find("/@Recycle/") != -1 - or file_item.path.find("/#recycle/") != -1 - or file_item.path.find("/.") != -1 - or file_item.path.find("/@eaDir") != -1 + item.path.find("/@Recycle/") != -1 + or item.path.find("/#recycle/") != -1 + or item.path.find("/.") != -1 + or item.path.find("/@eaDir") != -1 ): - logger.debug(f"{file_item.path} 是回收站或隐藏的文件") + logger.debug(f"{item.path} 是回收站或隐藏的文件") return False # 整理屏蔽词不处理 if self._is_blocked_by_exclude_words( - file_item.path, transfer_exclude_words + item.path, transfer_exclude_words ): return False return True @@ -1929,11 +2029,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return all_success, error_msg def remote_transfer( - self, - arg_str: str, - channel: MessageChannel, - userid: Union[str, int] = None, - source: Optional[str] = None, + self, + arg_str: str, + channel: MessageChannel, + userid: Union[str, int] = None, + source: Optional[str] = None, ): """ 远程重新整理,参数 历史记录ID TMDBID|类型 @@ -1945,7 +2045,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): channel=channel, source=source, title="请输入正确的命令格式:/redo [id] 或 /redo [id] [tmdbid/豆瓣id]|[类型]," - "[id] 为整理记录编号", + "[id] 为整理记录编号", userid=userid, ) ) @@ -2005,7 +2105,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): @staticmethod def build_failed_transfer_buttons( - history_id: Optional[int], + history_id: Optional[int], ) -> Optional[List[List[dict]]]: """ 构建整理失败通知的操作按钮。 @@ -2029,7 +2129,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return self.__re_transfer(logid=history_id) def __re_transfer( - self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None + self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None ) -> Tuple[bool, str]: """ 根据历史记录,重新识别整理,只支持简单条件 @@ -2088,25 +2188,25 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return True, "" def manual_transfer( - self, - fileitem: FileItem, - target_storage: Optional[str] = None, - target_path: Path = None, - tmdbid: Optional[int] = None, - doubanid: Optional[str] = None, - mtype: MediaType = None, - season: Optional[int] = None, - episode_group: Optional[str] = None, - transfer_type: Optional[str] = None, - epformat: EpisodeFormat = None, - min_filesize: Optional[int] = 0, - scrape: Optional[bool] = None, - library_type_folder: Optional[bool] = None, - library_category_folder: Optional[bool] = None, - force: Optional[bool] = False, - background: Optional[bool] = False, - downloader: Optional[str] = None, - download_hash: Optional[str] = None, + self, + fileitem: FileItem, + target_storage: Optional[str] = None, + target_path: Path = None, + tmdbid: Optional[int] = None, + doubanid: Optional[str] = None, + mtype: MediaType = None, + season: Optional[int] = None, + episode_group: Optional[str] = None, + transfer_type: Optional[str] = None, + epformat: EpisodeFormat = None, + min_filesize: Optional[int] = 0, + scrape: Optional[bool] = None, + library_type_folder: Optional[bool] = None, + library_category_folder: Optional[bool] = None, + force: Optional[bool] = False, + background: Optional[bool] = False, + downloader: Optional[str] = None, + download_hash: Optional[str] = None, ) -> Tuple[bool, Union[str, list]]: """ 手动整理,支持复杂条件,带进度显示 @@ -2194,12 +2294,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return state, errmsg def send_transfer_message( - self, - meta: MetaBase, - mediainfo: MediaInfo, - transferinfo: TransferInfo, - season_episode: Optional[str] = None, - username: Optional[str] = None, + self, + meta: MetaBase, + mediainfo: MediaInfo, + transferinfo: TransferInfo, + season_episode: Optional[str] = None, + username: Optional[str] = None, ): """ 发送入库成功的消息 @@ -2237,7 +2337,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): return False def _can_delete_torrent( - self, download_hash: str, downloader: str, transfer_exclude_words + self, download_hash: str, downloader: str, transfer_exclude_words ) -> bool: """ 检查是否可以删除种子文件 @@ -2270,11 +2370,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton): file_path = save_path / file.name # 如果存在未被屏蔽的媒体文件,则不删除种子 if ( - file_path.suffix in self._allowed_exts - and not self._is_blocked_by_exclude_words( - file_path.as_posix(), transfer_exclude_words - ) - and file_path.exists() + file_path.suffix in self._allowed_exts + and not self._is_blocked_by_exclude_words( + file_path.as_posix(), transfer_exclude_words + ) + and file_path.exists() ): return False diff --git a/app/core/config.py b/app/core/config.py index e62247ca..da17c4c2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -506,7 +506,7 @@ class ConfigModel(BaseModel): # LLM模型名称 LLM_MODEL: str = "deepseek-chat" # 思考模式/深度配置:off/auto/minimal/low/medium/high/max/xhigh - LLM_THINKING_LEVEL: Optional[str] = 'off' + LLM_THINKING_LEVEL: Optional[str] = "off" # LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型 LLM_SUPPORT_IMAGE_INPUT: bool = True # LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复 diff --git a/app/schemas/history.py b/app/schemas/history.py index f6d489ed..243bb77d 100644 --- a/app/schemas/history.py +++ b/app/schemas/history.py @@ -1,6 +1,6 @@ from typing import Optional, Any -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field class DownloadHistory(BaseModel): @@ -97,3 +97,7 @@ class TransferHistory(BaseModel): date: Optional[str] = None model_config = ConfigDict(from_attributes=True) + + +class BatchTransferHistoryRedoRequest(BaseModel): + history_ids: list[int] = Field(default_factory=list) diff --git a/app/startup/agent_initializer.py b/app/startup/agent_initializer.py index 81033405..83d951e5 100644 --- a/app/startup/agent_initializer.py +++ b/app/startup/agent_initializer.py @@ -10,10 +10,10 @@ class AgentInitializer: """ AI智能体初始化器 """ - + def __init__(self): self._initialized = False - + async def initialize(self) -> bool: """ 初始化AI智能体管理器 @@ -22,16 +22,16 @@ class AgentInitializer: if not settings.AI_AGENT_ENABLE: logger.info("AI智能体功能未启用") return True - + await agent_manager.initialize() self._initialized = True logger.info("AI智能体管理器初始化成功") return True - + except Exception as e: logger.error(f"AI智能体管理器初始化失败: {e}") return False - + async def cleanup(self) -> None: """ 清理AI智能体管理器 @@ -39,11 +39,10 @@ class AgentInitializer: try: if not self._initialized: return - await agent_manager.close() self._initialized = False logger.info("AI智能体管理器已关闭") - + except Exception as e: logger.error(f"关闭AI智能体管理器时发生错误: {e}") @@ -60,7 +59,7 @@ def init_agent(): if not settings.AI_AGENT_ENABLE: logger.info("AI智能体功能未启用") return True - + # 在新的事件循环中初始化AI智能体管理器 def run_init(): loop = asyncio.new_event_loop() @@ -77,13 +76,13 @@ def init_agent(): return False finally: loop.close() - + # 在后台线程中初始化 init_thread = threading.Thread(target=run_init, daemon=True) init_thread.start() - + return True - + except Exception as e: logger.error(f"初始化AI智能体时发生错误: {e}") return False diff --git a/tests/test_agent_prompt_style.py b/tests/test_agent_prompt_style.py index 01bcddb4..c1b1d0b0 100644 --- a/tests/test_agent_prompt_style.py +++ b/tests/test_agent_prompt_style.py @@ -77,6 +77,21 @@ class TestAgentPromptStyle(unittest.TestCase): self.assertIn("Total failed records: 1", message) self.assertIn("history_id=7", message) + def test_render_batch_manual_transfer_redo_message(self): + message = prompt_manager.render_system_task_message( + "batch_manual_transfer_redo", + template_context={ + "history_ids_csv": "7, 8", + "history_count": 2, + "records_context": "Record #7:\n- Source path: /downloads/a.mkv", + }, + ) + + self.assertIn("[System Task - Batch Manual Transfer Re-Organize]", message) + self.assertIn("History IDs: 7, 8", message) + self.assertIn("Total records: 2", message) + self.assertIn("Record #7:", message) + def test_missing_system_task_template_context_raises_clear_error(self): with self.assertRaises(PromptConfigError): prompt_manager.render_system_task_message("transfer_failed_retry") diff --git a/tests/test_transfer_failed_retry_buttons.py b/tests/test_transfer_failed_retry_buttons.py index 74d9a611..349f01a2 100644 --- a/tests/test_transfer_failed_retry_buttons.py +++ b/tests/test_transfer_failed_retry_buttons.py @@ -1,6 +1,7 @@ import unittest import sys from types import ModuleType +from types import SimpleNamespace from unittest.mock import patch sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi")) @@ -74,9 +75,33 @@ class TestTransferFailedRetryButtons(unittest.TestCase): def test_transfer_ai_retry_callback_schedules_agent_takeover(self): chain = MessageChain() + history = SimpleNamespace( + id=34, + status=False, + title="Test Show", + type="电视剧", + category=None, + year="2024", + seasons="S01", + episodes="E01", + src="/downloads/Test.Show.S01E01.mkv", + src_storage="local", + src_fileitem={"path": "/downloads/Test.Show.S01E01.mkv"}, + dest=None, + dest_storage=None, + mode="copy", + tmdbid=123, + doubanid=None, + errmsg="未识别到媒体信息", + ) with patch.object(settings, "AI_AGENT_ENABLE", True): - with patch("app.chain.message.asyncio.run_coroutine_threadsafe") as run_task: + with patch( + "app.chain.message.TransferHistoryOper" + ) as history_oper_cls, patch( + "app.chain.message.asyncio.run_coroutine_threadsafe" + ) as run_task: + history_oper_cls.return_value.get.return_value = history with patch.object(chain, "post_message") as post_message: chain._handle_callback( text="CALLBACK:transfer_ai_retry_34",