mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-30 12:11:49 +08:00
fix: improve web agent stream recovery
This commit is contained in:
@@ -1743,6 +1743,12 @@ class AgentManager:
|
||||
queue = self._session_queues.get(session_id)
|
||||
return bool(queue and not queue.empty())
|
||||
|
||||
def is_session_busy(self, session_id: str) -> bool:
|
||||
"""
|
||||
查询会话是否仍有正在执行或排队的任务。
|
||||
"""
|
||||
return self._is_session_busy(session_id)
|
||||
|
||||
def _expired_idle_sessions(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
收集已经超过空闲时间且当前不忙的会话。
|
||||
|
||||
@@ -52,6 +52,7 @@ _WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {}
|
||||
_WEB_AGENT_NOTICE_QUEUES: dict[str, list[Queue[schemas.Notification]]] = {}
|
||||
_WEB_AGENT_NOTICE_LOCK = Lock()
|
||||
_WEB_AGENT_NOTICE_LISTENER_REGISTERED = False
|
||||
_WEB_AGENT_BACKGROUND_TASKS: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
class _WebAgentStreamingHandler(StreamingHandler):
|
||||
@@ -254,7 +255,7 @@ def _apply_web_agent_display_event(event: dict, assistant_message: dict) -> None
|
||||
assistant_message["tools"].append(
|
||||
{
|
||||
"id": f"tool-{uuid.uuid4().hex}",
|
||||
"message": str(event.get("message") or "").replace("=>", "", 1).strip(),
|
||||
"message": str(event.get("message") or "").strip(),
|
||||
"status": "running",
|
||||
}
|
||||
)
|
||||
@@ -1259,24 +1260,21 @@ def _split_web_agent_output(text: str) -> list[dict]:
|
||||
events = []
|
||||
|
||||
def append_text(content: str) -> None:
|
||||
"""将工具汇总行从普通文本中拆出,便于前端独立展示。"""
|
||||
"""将工具汇总行从普通文本中拆出,保留与消息渠道一致的展示文案。"""
|
||||
if not content:
|
||||
return
|
||||
lines = content.splitlines(keepends=True)
|
||||
buffer = ""
|
||||
for line in lines:
|
||||
stripped_line = line.strip()
|
||||
if (
|
||||
stripped_line.startswith("(")
|
||||
and stripped_line.endswith(")")
|
||||
):
|
||||
if stripped_line.startswith("(") and stripped_line.endswith(")"):
|
||||
if buffer:
|
||||
events.append({"type": "delta", "content": buffer})
|
||||
buffer = ""
|
||||
events.append(
|
||||
{
|
||||
"type": "tool",
|
||||
"message": stripped_line.strip("()"),
|
||||
"message": stripped_line,
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -1305,7 +1303,7 @@ def _split_web_agent_output(text: str) -> list[dict]:
|
||||
remaining = after_marker[line_end:].lstrip("\n")
|
||||
|
||||
if message:
|
||||
events.append({"type": "tool", "message": message})
|
||||
events.append({"type": "tool", "message": f"{marker}{message}"})
|
||||
|
||||
return events
|
||||
|
||||
@@ -1469,10 +1467,28 @@ async def get_agent_chat_session(
|
||||
:param db: 异步数据库会话
|
||||
:return: 会话详情
|
||||
"""
|
||||
chat = await _get_accessible_agent_chat(AgentChatOper(db), session_id, current_user)
|
||||
oper = AgentChatOper(db)
|
||||
chat = await _get_accessible_agent_chat(oper, session_id, current_user)
|
||||
server_session_id = session_id
|
||||
if not chat:
|
||||
server_session_id = _build_web_agent_session_id(current_user, session_id)
|
||||
if server_session_id != session_id:
|
||||
chat = await _get_accessible_agent_chat(oper, server_session_id, current_user)
|
||||
if not chat:
|
||||
if agent_manager.is_session_busy(server_session_id):
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
"session_id": server_session_id,
|
||||
"client_session_id": session_id,
|
||||
"messages": [],
|
||||
"is_processing": True,
|
||||
},
|
||||
)
|
||||
return schemas.Response(success=False, message="会话不存在或无权访问")
|
||||
return schemas.Response(success=True, data=AgentChatOper.to_detail(chat))
|
||||
data = AgentChatOper.to_detail(chat)
|
||||
data["is_processing"] = agent_manager.is_session_busy(chat.session_id)
|
||||
return schemas.Response(success=True, data=data)
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}/display", summary="保存 Agent 展示会话", response_model=schemas.Response)
|
||||
@@ -1535,6 +1551,37 @@ async def delete_agent_chat_session(
|
||||
return schemas.Response(success=deleted, message="删除成功" if deleted else "删除失败")
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/stop", summary="停止 Web 智能助手当前任务", response_model=schemas.Response)
|
||||
async def stop_web_agent_session_task(
|
||||
session_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
停止当前 Web 智能助手会话正在执行的任务。
|
||||
|
||||
:param session_id: Agent 会话 ID
|
||||
:param current_user: 当前登录用户
|
||||
:param db: 异步数据库会话
|
||||
:return: 停止结果
|
||||
"""
|
||||
server_session_id = _build_web_agent_session_id(current_user, session_id)
|
||||
chat = await _get_accessible_agent_chat(
|
||||
AgentChatOper(db), server_session_id, current_user
|
||||
)
|
||||
if not chat and server_session_id != session_id:
|
||||
chat = await _get_accessible_agent_chat(AgentChatOper(db), session_id, current_user)
|
||||
if chat and not _can_access_agent_chat(chat, current_user):
|
||||
return schemas.Response(success=False, message="会话不存在或无权访问")
|
||||
|
||||
stopped = await agent_manager.stop_current_task(server_session_id)
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={"stopped": stopped},
|
||||
message="已停止" if stopped else "当前没有正在执行的任务",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stream", summary="Web智能助手流式对话")
|
||||
async def web_agent_stream(
|
||||
payload: schemas.AgentWebChatRequest,
|
||||
@@ -1772,28 +1819,34 @@ async def web_agent_stream(
|
||||
await event_queue.put(done_event)
|
||||
|
||||
task = asyncio.create_task(run_agent())
|
||||
_WEB_AGENT_BACKGROUND_TASKS.add(task)
|
||||
task.add_done_callback(_WEB_AGENT_BACKGROUND_TASKS.discard)
|
||||
try:
|
||||
yield _build_web_agent_sse(
|
||||
"start",
|
||||
{"session_id": session_id},
|
||||
)
|
||||
disconnected = False
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
disconnected = True
|
||||
break
|
||||
event = await event_queue.get()
|
||||
yield _build_web_agent_sse(event.pop("type"), event)
|
||||
if task.done() and event_queue.empty():
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
disconnected = True
|
||||
return
|
||||
finally:
|
||||
if not task.done():
|
||||
if not task.done() and not disconnected:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# WebAgent 会话由 AgentManager 统一管理,空闲清理或 /clear_session 时释放。
|
||||
# 客户端退到后台导致 SSE 断开时,保留后台 Agent 继续执行;完成后会保存展示快照,
|
||||
# 前端恢复可见时可通过会话详情接口拉取最终状态。
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -37,22 +37,22 @@ from app.schemas.types import EventType, MessageChannel, NotificationType
|
||||
|
||||
|
||||
def test_split_web_agent_output_extracts_verbose_tool_message():
|
||||
"""应将啰嗦模式工具提示拆成独立工具事件。"""
|
||||
"""应将啰嗦模式工具提示拆成独立工具事件,并保留渠道展示文案。"""
|
||||
events = _split_web_agent_output("准备查询。\n\n⚙️ => 查询站点\n\n已完成")
|
||||
|
||||
assert events == [
|
||||
{"type": "delta", "content": "准备查询。\n\n"},
|
||||
{"type": "tool", "message": "查询站点"},
|
||||
{"type": "tool", "message": "⚙️ => 查询站点"},
|
||||
{"type": "delta", "content": "已完成"},
|
||||
]
|
||||
|
||||
|
||||
def test_split_web_agent_output_extracts_summary_tool_message():
|
||||
"""应将非啰嗦模式工具汇总行拆成独立工具事件。"""
|
||||
"""应将非啰嗦模式工具汇总行拆成独立工具事件,并保留渠道展示文案。"""
|
||||
events = _split_web_agent_output("(查询了 2 次数据)\n\n这里是结果")
|
||||
|
||||
assert events == [
|
||||
{"type": "tool", "message": "查询了 2 次数据"},
|
||||
{"type": "tool", "message": "(查询了 2 次数据)"},
|
||||
{"type": "delta", "content": "\n这里是结果"},
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user