Files
Foxel/domain/agent/service.py

444 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple
import httpx
from fastapi import HTTPException
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
from domain.auth import User
from .mcp import mcp_client_session, mcp_content_to_text
from .tools import tool_result_to_content
from .types import AgentChatRequest, PendingMcpCall
def _normalize_path(path: Optional[str]) -> Optional[str]:
if not path:
return None
value = str(path).strip().replace("\\", "/")
if not value:
return None
if not value.startswith("/"):
value = "/" + value
return value.rstrip("/") or "/"
def _build_system_prompt(current_path: Optional[str]) -> str:
lines = [
"你是 Foxel 的 AI 助手。",
"你可以通过 MCP 工具对文件/目录进行查询、读写、移动、复制、删除以及运行处理器processor",
"",
"可用工具:",
"- time获取服务器当前时间精确到秒英文星期支持 year/month/day/hour/minute/second 偏移。",
"- web_fetch抓取网页HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS返回状态/标题/正文/链接等。",
"- vfs_list_dir浏览目录列出 entries + pagination",
"- vfs_stat查看文件/目录信息。",
"- vfs_read_text读取文本文件内容不支持二进制",
"- vfs_search搜索文件vector/filename",
"- vfs_write_text写入文本文件内容覆盖",
"- vfs_mkdir创建目录。",
"- vfs_delete删除文件或目录。",
"- vfs_move移动路径。",
"- vfs_copy复制路径。",
"- vfs_rename重命名路径。",
"- processors_list获取可用处理器列表含 type/name/config_schema/produces_file/supports_directory",
"- processors_run运行处理器处理文件或目录会返回 task_id 或 task_ids",
"",
"规则:",
"1) 读操作web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search可直接调用工具。",
"2) 写/改/删操作vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run默认需要用户确认只有在开启自动执行时才应直接执行。",
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
"4) 修改文件内容先读取vfs_read_text→给出改动点→确认后再写入vfs_write_text",
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
"6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
]
if current_path:
lines.append("")
lines.append(f"当前文件管理目录:{current_path}")
return "\n".join(lines)
def _ensure_mcp_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
mcp_calls = message.get("mcp_calls")
if not isinstance(mcp_calls, list):
return message
changed = False
for idx, call in enumerate(mcp_calls):
if not isinstance(call, dict):
continue
call_id = call.get("id")
if isinstance(call_id, str) and call_id.strip():
continue
call["id"] = f"call_{idx}"
changed = True
if changed:
message["mcp_calls"] = mcp_calls
return message
def _extract_pending(mcp_call: Dict[str, Any], requires_confirmation: bool) -> PendingMcpCall:
    """Convert a raw MCP call dict into a ``PendingMcpCall`` record.

    Missing or falsy ``id``/``name`` values become empty strings, and
    ``arguments`` falls back to an empty dict when it is not a dict.
    """
    raw_arguments = mcp_call.get("arguments")
    call_id = str(mcp_call.get("id") or "")
    call_name = str(mcp_call.get("name") or "")
    return PendingMcpCall(
        id=call_id,
        name=call_name,
        arguments=raw_arguments if isinstance(raw_arguments, dict) else {},
        requires_confirmation=requires_confirmation,
    )
def _find_last_assistant_mcp_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
for idx in range(len(messages) - 1, -1, -1):
msg = messages[idx]
if not isinstance(msg, dict):
continue
if msg.get("role") != "assistant":
continue
mcp_calls = msg.get("mcp_calls")
if isinstance(mcp_calls, list) and mcp_calls:
return idx, msg
raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
def _existing_mcp_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
ids: set[str] = set()
for msg in messages:
if not isinstance(msg, dict):
continue
if msg.get("role") != "tool":
continue
call_id = msg.get("mcp_call_id")
if isinstance(call_id, str) and call_id.strip():
ids.add(call_id)
return ids
def _tool_requires_confirmation(tool_descriptor: Dict[str, Any]) -> bool:
meta = tool_descriptor.get("meta") if isinstance(tool_descriptor.get("meta"), dict) else {}
if "requires_confirmation" in meta:
return bool(meta.get("requires_confirmation"))
annotations = tool_descriptor.get("annotations") if isinstance(tool_descriptor.get("annotations"), dict) else {}
return not bool(annotations.get("readOnlyHint"))
async def _choose_chat_ability() -> str:
    """Pick the ability name used for chat completions.

    Returns ``"tools"`` when ``AIProviderService`` reports a default model
    for the "tools" ability, otherwise ``"chat"``.
    """
    if await AIProviderService.get_default_model("tools"):
        return "tools"
    return "chat"
def _sse(event: str, data: Any) -> bytes:
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
def _format_exc(exc: BaseException) -> str:
text = str(exc)
return text if text else exc.__class__.__name__
async def _list_mcp_tools(session) -> List[Dict[str, Any]]:
result = await session.list_tools()
tools: List[Dict[str, Any]] = []
for item in result.tools:
annotations = getattr(item, "annotations", None)
meta = getattr(item, "meta", None)
tools.append(
{
"name": str(getattr(item, "name", "") or ""),
"description": str(getattr(item, "description", "") or ""),
"input_schema": getattr(item, "inputSchema", None) or {},
"annotations": annotations.model_dump(exclude_none=True) if annotations is not None else {},
"meta": meta if isinstance(meta, dict) else {},
}
)
return tools
async def _execute_mcp_call(session, name: str, arguments: Dict[str, Any]) -> str:
    """Invoke the MCP tool ``name`` with ``arguments`` and return its result as text.

    The result's ``content`` and ``structuredContent`` are converted with
    ``mcp_content_to_text``. Exceptions raised by the call propagate.
    """
    outcome = await session.call_tool(name, arguments)
    text = mcp_content_to_text(outcome.content, outcome.structuredContent)
    return text
class AgentService:
    """Runs one agent turn: resolve confirmed tool calls, then loop model <-> MCP tools.

    ``chat`` returns the whole turn as a single payload; ``chat_stream``
    produces the same turn incrementally as server-sent event frames.
    """

    # Cap on model round-trips per request, so a model that keeps requesting
    # tools cannot loop forever.
    _MAX_ROUNDS = 8
    # Timeout value passed to every chat completion call.
    _COMPLETION_TIMEOUT = 60.0

    @staticmethod
    def _clean_ids(raw: Any) -> set:
        """Return the non-blank string ids contained in ``raw`` (None counts as empty)."""
        return {item for item in (raw or []) if isinstance(item, str) and item.strip()}

    @staticmethod
    def _tool_message(call_id: str, content: str) -> Dict[str, Any]:
        """Build the tool-result message answering the call ``call_id``."""
        return {"role": "tool", "mcp_call_id": call_id, "content": content}

    @staticmethod
    def _call_parts(call: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        """Extract ``(name, arguments)`` from a call dict, defaulting to ``""`` and ``{}``."""
        raw_arguments = call.get("arguments")
        name = str(call.get("name") or "")
        return name, raw_arguments if isinstance(raw_arguments, dict) else {}

    @staticmethod
    def _result_payload(new_messages: List[Dict[str, Any]], pending: List[PendingMcpCall]) -> Dict[str, Any]:
        """Build the final payload: new messages plus, if any, the calls awaiting confirmation."""
        payload: Dict[str, Any] = {"messages": new_messages}
        if pending:
            payload["pending_mcp_calls"] = [item.model_dump() for item in pending]
        return payload

    @staticmethod
    def _decided_calls(
        messages: List[Dict[str, Any]],
        approved_ids: set,
        rejected_ids: set,
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """List the approved/rejected calls that still need a tool result.

        The calls come from the last assistant message carrying MCP calls and
        are returned in the order the assistant issued them, so approved
        tools run in a deterministic order rather than in set-iteration
        order. Ids that already have a tool result in ``messages`` and ids
        not present in that assistant message are skipped.

        Raises:
            HTTPException: status 400 when ids were supplied but no assistant
                message with MCP calls exists.
        """
        decided = approved_ids | rejected_ids
        if not decided:
            return []
        _, last_call_msg = _find_last_assistant_mcp_calls(messages)
        last_call_msg = _ensure_mcp_call_ids(last_call_msg)
        call_map: Dict[str, Dict[str, Any]] = {
            str(call.get("id")): call
            for call in (last_call_msg.get("mcp_calls") or [])
            if isinstance(call, dict) and isinstance(call.get("id"), str)
        }
        existing_ids = _existing_mcp_result_ids(messages)
        # dict preserves insertion order, i.e. the assistant's call order.
        return [
            (call_id, call)
            for call_id, call in call_map.items()
            if call_id in decided and call_id not in existing_ids
        ]

    @staticmethod
    async def _run_tool(session, name: str, arguments: Dict[str, Any]) -> str:
        """Execute a tool; a failure is returned as an ``{"error": ...}`` result instead of raised."""
        try:
            return await _execute_mcp_call(session, name, arguments)
        except Exception as exc:  # noqa: BLE001
            # _format_exc falls back to the class name, so the model never
            # receives an empty error text.
            return tool_result_to_content({"error": _format_exc(exc)})

    @classmethod
    async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
        """Run one agent turn and return everything it produced.

        First, calls the user approved or rejected (``req.approved_mcp_call_ids``
        / ``req.rejected_mcp_call_ids``) are answered with tool messages:
        rejected calls get a "canceled" result, approved calls are executed.
        Then the model is queried up to ``_MAX_ROUNDS`` times; tools it
        requests are executed immediately unless they require confirmation
        and ``req.auto_execute`` is off, in which case they are reported as
        pending and the turn ends.

        Returns:
            ``{"messages": [...]}`` with the assistant/tool messages created
            during this turn, plus ``"pending_mcp_calls"`` when calls await
            confirmation.

        Raises:
            HTTPException: 400 when no model is configured or there is no
                call to confirm; 502 when the completion request fails.
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)
        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingMcpCall] = []
        approved_ids = cls._clean_ids(req.approved_mcp_call_ids)
        rejected_ids = cls._clean_ids(req.rejected_mcp_call_ids)

        def record(message: Dict[str, Any]) -> None:
            # Every new message is both fed back to the model and returned.
            internal_messages.append(message)
            new_messages.append(message)

        async with mcp_client_session(user, current_path) as mcp_session:
            tools_schema = await _list_mcp_tools(mcp_session)
            tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")}

            for call_id, mcp_call in cls._decided_calls(internal_messages, approved_ids, rejected_ids):
                name, arguments = cls._call_parts(mcp_call)
                if call_id in rejected_ids:
                    content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                elif not tool_index.get(name):
                    content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                else:
                    content = await cls._run_tool(mcp_session, name, arguments)
                record(cls._tool_message(call_id, content))

            ability = await _choose_chat_ability()
            for _ in range(cls._MAX_ROUNDS):
                try:
                    assistant = await chat_completion(
                        internal_messages,
                        ability=ability,
                        tools=tools_schema,
                        tool_choice="auto",
                        timeout=cls._COMPLETION_TIMEOUT,
                    )
                except MissingModelError as exc:
                    raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
                except httpx.HTTPStatusError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
                except httpx.RequestError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc

                assistant = _ensure_mcp_call_ids(
                    assistant if isinstance(assistant, dict) else {"role": "assistant", "content": ""}
                )
                record(assistant)

                mcp_calls = assistant.get("mcp_calls")
                if not isinstance(mcp_calls, list) or not mcp_calls:
                    break  # plain answer: the turn is complete

                pending = []
                for call in mcp_calls:
                    if not isinstance(call, dict):
                        continue
                    call_id = str(call.get("id") or "")
                    name, arguments = cls._call_parts(call)
                    tool_desc = tool_index.get(name)
                    if not tool_desc:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        record(cls._tool_message(call_id, content))
                        continue
                    if _tool_requires_confirmation(tool_desc) and not req.auto_execute:
                        pending.append(_extract_pending(call, True))
                        continue
                    content = await cls._run_tool(mcp_session, name, arguments)
                    record(cls._tool_message(call_id, content))

                if pending:
                    break  # wait for the user's decision before asking the model again

        return cls._result_payload(new_messages, pending)

    @classmethod
    async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
        """Run one agent turn, yielding its progress as server-sent event frames.

        The flow matches ``chat``. Emitted events: ``mcp_call_start`` /
        ``mcp_call_end`` around tool executions (only ``mcp_call_end`` for
        rejected or unknown tools), ``assistant_start`` / ``assistant_delta``
        / ``assistant_end`` for each model reply, ``pending`` when calls
        await confirmation, and a final ``done`` carrying the same payload
        ``chat`` would return.

        Failures are not raised to the consumer: they are appended as an
        assistant message and delivered in the ``done`` event. Cancellation
        is propagated.
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)
        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingMcpCall] = []
        approved_ids = cls._clean_ids(req.approved_mcp_call_ids)
        rejected_ids = cls._clean_ids(req.rejected_mcp_call_ids)

        def record(message: Dict[str, Any]) -> None:
            # Every new message is both fed back to the model and returned.
            internal_messages.append(message)
            new_messages.append(message)

        try:
            async with mcp_client_session(user, current_path) as mcp_session:
                tools_schema = await _list_mcp_tools(mcp_session)
                tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")}

                for call_id, mcp_call in cls._decided_calls(internal_messages, approved_ids, rejected_ids):
                    name, arguments = cls._call_parts(mcp_call)
                    if call_id in rejected_ids:
                        content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                    elif not tool_index.get(name):
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    else:
                        yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name})
                        content = await cls._run_tool(mcp_session, name, arguments)
                    tool_msg = cls._tool_message(call_id, content)
                    record(tool_msg)
                    yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg})

                ability = await _choose_chat_ability()
                for _ in range(cls._MAX_ROUNDS):
                    assistant_event_id = str(uuid.uuid4())
                    yield _sse("assistant_start", {"id": assistant_event_id})
                    assistant_message: Optional[Dict[str, Any]] = None
                    try:
                        async for event in chat_completion_stream(
                            internal_messages,
                            ability=ability,
                            tools=tools_schema,
                            tool_choice="auto",
                            timeout=cls._COMPLETION_TIMEOUT,
                        ):
                            event_type = event.get("type")
                            if event_type == "delta":
                                delta = event.get("delta")
                                if isinstance(delta, str) and delta:
                                    yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
                            elif event_type == "message":
                                msg = event.get("message")
                                if isinstance(msg, dict):
                                    assistant_message = msg
                    except MissingModelError as exc:
                        raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
                    except httpx.HTTPStatusError as exc:
                        raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
                    except httpx.RequestError as exc:
                        raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc

                    if not assistant_message:
                        assistant_message = {"role": "assistant", "content": ""}
                    assistant_message = _ensure_mcp_call_ids(assistant_message)
                    record(assistant_message)
                    yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})

                    mcp_calls = assistant_message.get("mcp_calls")
                    if not isinstance(mcp_calls, list) or not mcp_calls:
                        break  # plain answer: the turn is complete

                    pending = []
                    for call in mcp_calls:
                        if not isinstance(call, dict):
                            continue
                        call_id = str(call.get("id") or "")
                        name, arguments = cls._call_parts(call)
                        tool_desc = tool_index.get(name)
                        if not tool_desc:
                            content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        elif _tool_requires_confirmation(tool_desc) and not req.auto_execute:
                            pending.append(_extract_pending(call, True))
                            continue
                        else:
                            yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name})
                            content = await cls._run_tool(mcp_session, name, arguments)
                        tool_msg = cls._tool_message(call_id, content)
                        record(tool_msg)
                        yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg})

                    if pending:
                        yield _sse("pending", {"pending_mcp_calls": [item.model_dump() for item in pending]})
                        break  # wait for the user's decision before asking the model again

            # Emitted once the MCP session has been closed, so a teardown
            # failure is reported by the handlers below instead of producing
            # a second "done" event.
            yield _sse("done", cls._result_payload(new_messages, pending))
        except asyncio.CancelledError:
            # Cancellation must reach the task that is driving this
            # generator; swallowing it would make the task look as if it
            # finished normally.
            raise
        except HTTPException as exc:
            detail = exc.detail
            content = detail if isinstance(detail, str) else str(detail)
            if not content.strip():
                content = f"请求失败（{exc.status_code}）"
            new_messages.append({"role": "assistant", "content": content})
            yield _sse("done", cls._result_payload(new_messages, pending))
        except Exception as exc:  # noqa: BLE001
            new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
            yield _sse("done", cls._result_payload(new_messages, pending))