mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-11 09:59:50 +08:00
444 lines
21 KiB
Python
444 lines
21 KiB
Python
import asyncio
|
||
import json
|
||
import uuid
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import httpx
|
||
from fastapi import HTTPException
|
||
|
||
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
|
||
from domain.auth import User
|
||
|
||
from .mcp import mcp_client_session, mcp_content_to_text
|
||
from .tools import tool_result_to_content
|
||
from .types import AgentChatRequest, PendingMcpCall
|
||
|
||
|
||
def _normalize_path(path: Optional[str]) -> Optional[str]:
    """Turn a user-supplied path into a canonical absolute, slash-separated form.

    Backslashes become forward slashes, surrounding whitespace is removed, a
    leading ``/`` is added when missing and trailing slashes are dropped (the
    root itself stays ``/``).

    Returns:
        The normalized path, or ``None`` when ``path`` is falsy or blank.
    """
    cleaned = str(path).strip().replace("\\", "/") if path else ""
    if not cleaned:
        return None

    rooted = cleaned if cleaned.startswith("/") else "/" + cleaned
    trimmed = rooted.rstrip("/")
    # A path made only of slashes collapses to the empty string; map it to root.
    return trimmed if trimmed else "/"
|
||
|
||
|
||
def _build_system_prompt(current_path: Optional[str]) -> str:
    """Compose the system prompt placed at the start of every agent conversation.

    The (Chinese) prompt introduces the assistant, lists the available MCP
    tools and states the usage rules. When ``current_path`` is truthy, a blank
    line followed by a line naming the current file-manager directory is
    appended.

    Returns:
        The prompt lines joined with ``\\n``.
    """
    prompt_lines = (
        "你是 Foxel 的 AI 助手。",
        "你可以通过 MCP 工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
        "",
        "可用工具:",
        "- time:获取服务器当前时间(精确到秒,英文星期),支持 year/month/day/hour/minute/second 偏移。",
        "- web_fetch:抓取网页(HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS,返回状态/标题/正文/链接等。",
        "- vfs_list_dir:浏览目录(列出 entries + pagination)。",
        "- vfs_stat:查看文件/目录信息。",
        "- vfs_read_text:读取文本文件内容(不支持二进制)。",
        "- vfs_search:搜索文件(vector/filename)。",
        "- vfs_write_text:写入文本文件内容(覆盖)。",
        "- vfs_mkdir:创建目录。",
        "- vfs_delete:删除文件或目录。",
        "- vfs_move:移动路径。",
        "- vfs_copy:复制路径。",
        "- vfs_rename:重命名路径。",
        "- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
        "- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
        "",
        "规则:",
        "1) 读操作(web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
        "2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
        "3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
        "4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
        "5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
        "6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
    )
    if not current_path:
        return "\n".join(prompt_lines)
    # Tell the model which directory the user is currently browsing.
    return "\n".join((*prompt_lines, "", f"当前文件管理目录:{current_path}"))
|
||
|
||
|
||
def _ensure_mcp_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
    """Give every MCP call in ``message`` a non-blank string id, in place.

    Calls that already carry a non-blank string ``id`` are left untouched.
    Any other dict entry receives ``call_<n>``, where ``n`` starts at the
    entry's index in the list and is incremented until the id is not already
    used in this message. The bump keeps ids unique within the message, so
    that an id-keyed lookup never collapses two distinct calls into one.
    Assignment is deterministic: the same input always yields the same ids.

    Args:
        message: A chat message dict; only its ``mcp_calls`` entry is read.

    Returns:
        The same ``message`` object. It is returned unchanged when
        ``mcp_calls`` is not a list; non-dict entries in the list are ignored.
    """
    mcp_calls = message.get("mcp_calls")
    if not isinstance(mcp_calls, list):
        return message

    def _has_usable_id(call: Dict[str, Any]) -> bool:
        call_id = call.get("id")
        return isinstance(call_id, str) and bool(call_id.strip())

    # Ids supplied by the model are reserved up front so generated ids never
    # clash with them, regardless of their position in the list.
    used_ids = {
        call["id"]
        for call in mcp_calls
        if isinstance(call, dict) and _has_usable_id(call)
    }

    for idx, call in enumerate(mcp_calls):
        if not isinstance(call, dict) or _has_usable_id(call):
            continue
        suffix = idx
        while f"call_{suffix}" in used_ids:
            suffix += 1
        call["id"] = f"call_{suffix}"
        used_ids.add(call["id"])

    return message
|
||
|
||
|
||
def _extract_pending(mcp_call: Dict[str, Any], requires_confirmation: bool) -> PendingMcpCall:
    """Convert a raw MCP call dict into a ``PendingMcpCall`` awaiting a user decision.

    Missing or falsy ``id`` / ``name`` values become empty strings, and an
    ``arguments`` value that is not a dict is replaced by an empty dict.
    """
    raw_arguments = mcp_call.get("arguments")
    fields: Dict[str, Any] = {
        "id": str(mcp_call.get("id") or ""),
        "name": str(mcp_call.get("name") or ""),
        "arguments": raw_arguments if isinstance(raw_arguments, dict) else {},
        "requires_confirmation": requires_confirmation,
    }
    return PendingMcpCall(**fields)
|
||
|
||
|
||
def _find_last_assistant_mcp_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
    """Locate the most recent assistant message that carries MCP calls.

    The list is scanned from the end; non-dict entries, non-assistant
    messages and assistant messages whose ``mcp_calls`` is not a non-empty
    list are skipped.

    Returns:
        ``(index, message)`` of the match.

    Raises:
        HTTPException: 400 when no such message exists.
    """
    for position in reversed(range(len(messages))):
        candidate = messages[position]
        if isinstance(candidate, dict) and candidate.get("role") == "assistant":
            calls = candidate.get("mcp_calls")
            if isinstance(calls, list) and calls:
                return position, candidate

    raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
|
||
|
||
|
||
def _existing_mcp_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
    """Collect the call ids for which ``messages`` already holds a tool result.

    Only dict entries with ``role == "tool"`` and a non-blank string
    ``mcp_call_id`` contribute; ids are returned exactly as stored.
    """

    def _result_id(entry: Any) -> Optional[str]:
        if not isinstance(entry, dict) or entry.get("role") != "tool":
            return None
        candidate = entry.get("mcp_call_id")
        if isinstance(candidate, str) and candidate.strip():
            return candidate
        return None

    return {found for found in map(_result_id, messages) if found is not None}
|
||
|
||
|
||
def _tool_requires_confirmation(tool_descriptor: Dict[str, Any]) -> bool:
    """Decide whether a tool call needs explicit user approval before running.

    An explicit ``meta["requires_confirmation"]`` entry wins. Otherwise the
    tool needs confirmation unless ``annotations["readOnlyHint"]`` is truthy.
    ``meta`` / ``annotations`` values that are not dicts are treated as empty.
    """
    meta = tool_descriptor.get("meta")
    if isinstance(meta, dict) and "requires_confirmation" in meta:
        return bool(meta["requires_confirmation"])

    annotations = tool_descriptor.get("annotations")
    read_only_hint = annotations.get("readOnlyHint") if isinstance(annotations, dict) else None
    return not read_only_hint
|
||
|
||
|
||
async def _choose_chat_ability() -> str:
    """Pick the model ability used for chat completions.

    Returns ``"tools"`` when ``AIProviderService`` reports a default model for
    the ``"tools"`` ability, and ``"chat"`` otherwise.
    """
    if await AIProviderService.get_default_model("tools"):
        return "tools"
    return "chat"
|
||
|
||
|
||
def _sse(event: str, data: Any) -> bytes:
    """Encode one server-sent-events frame as UTF-8 bytes.

    ``data`` is serialized as compact JSON (no spaces after separators,
    non-ASCII characters kept verbatim) and emitted as a single ``data:`` line
    after the ``event:`` line, terminated by a blank line.
    """
    body = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
    frame = "event: {}\ndata: {}\n\n".format(event, body)
    return frame.encode("utf-8")
|
||
|
||
|
||
def _format_exc(exc: BaseException) -> str:
    """Return the exception's message, or its class name when the message is empty."""
    return str(exc) or exc.__class__.__name__
|
||
|
||
|
||
async def _list_mcp_tools(session) -> List[Dict[str, Any]]:
    """Fetch the tool list from the MCP session as plain descriptor dicts.

    Each descriptor has the keys ``name``, ``description``, ``input_schema``,
    ``annotations`` and ``meta``. Missing or falsy attributes fall back to an
    empty string / empty dict; ``annotations`` is dumped via ``model_dump``
    with ``None`` fields excluded, and ``meta`` is kept only when it is a dict.
    """

    def _describe(tool) -> Dict[str, Any]:
        tool_annotations = getattr(tool, "annotations", None)
        tool_meta = getattr(tool, "meta", None)
        return {
            "name": str(getattr(tool, "name", "") or ""),
            "description": str(getattr(tool, "description", "") or ""),
            "input_schema": getattr(tool, "inputSchema", None) or {},
            "annotations": {} if tool_annotations is None else tool_annotations.model_dump(exclude_none=True),
            "meta": tool_meta if isinstance(tool_meta, dict) else {},
        }

    listing = await session.list_tools()
    return [_describe(tool) for tool in listing.tools]
|
||
|
||
|
||
async def _execute_mcp_call(session, name: str, arguments: Dict[str, Any]) -> str:
    """Invoke the MCP tool ``name`` with ``arguments`` and return its result as text.

    The result's ``content`` and ``structuredContent`` are passed to
    ``mcp_content_to_text`` for conversion.
    """
    outcome = await session.call_tool(name, arguments)
    content, structured = outcome.content, outcome.structuredContent
    return mcp_content_to_text(content, structured)
|
||
|
||
|
||
class AgentService:
    """Run the agent conversation loop: model completions interleaved with MCP tool calls.

    Both entry points share the same flow:

    1. Prepend the system prompt to the client-supplied history.
    2. If the request approves/rejects previously pending tool calls, resolve
       them first (execute approved ones, record a cancellation for rejected
       ones).
    3. Ask the model for a completion, execute the tool calls it issues, and
       repeat (at most 8 model round-trips) until the model stops calling
       tools or a call needs user confirmation.
    """

    @staticmethod
    def _tool_message(call_id: str, content: Any) -> Dict[str, Any]:
        """Build the ``tool`` role message carrying the result of one MCP call."""
        return {"role": "tool", "mcp_call_id": call_id, "content": content}

    @staticmethod
    def _build_payload(new_messages: List[Dict[str, Any]], pending: List[PendingMcpCall]) -> Dict[str, Any]:
        """Build the final response payload; ``pending_mcp_calls`` is included only when non-empty."""
        payload: Dict[str, Any] = {"messages": new_messages}
        if pending:
            payload["pending_mcp_calls"] = [item.model_dump() for item in pending]
        return payload

    @staticmethod
    async def _safe_execute(mcp_session, name: str, arguments: Dict[str, Any]) -> Any:
        """Execute an MCP tool, turning any ``Exception`` into an ``{"error": ...}`` tool result.

        Tool failures are reported back to the model as content instead of
        aborting the whole conversation.
        """
        try:
            return await _execute_mcp_call(mcp_session, name, arguments)
        except Exception as exc:  # noqa: BLE001
            return tool_result_to_content({"error": str(exc)})

    @staticmethod
    def _calls_to_resolve(
        messages: List[Dict[str, Any]],
        approved_ids: set,
        rejected_ids: set,
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """Return the ``(call_id, call)`` pairs the user decided on that still lack a result.

        Pairs are returned in the order the assistant issued the calls, so
        approved tools run in that order (iterating the id sets directly would
        give an arbitrary order). Only tool results that appear *after* the
        assistant message are treated as already answered: ids may repeat
        across turns, and an older result must not suppress a new call.

        Raises:
            HTTPException: 400 (from ``_find_last_assistant_mcp_calls``) when
                the history holds no assistant message with MCP calls.
        """
        last_idx, last_call_msg = _find_last_assistant_mcp_calls(messages)
        last_call_msg = _ensure_mcp_call_ids(last_call_msg)
        mcp_calls = last_call_msg.get("mcp_calls") or []
        # dict keeps insertion order, so this preserves the assistant's ordering.
        call_map: Dict[str, Dict[str, Any]] = {
            str(call.get("id")): call
            for call in mcp_calls
            if isinstance(call, dict) and isinstance(call.get("id"), str)
        }
        existing_ids = _existing_mcp_result_ids(messages[last_idx + 1:])
        decided_ids = approved_ids | rejected_ids
        return [
            (call_id, call)
            for call_id, call in call_map.items()
            if call_id in decided_ids and call_id not in existing_ids
        ]

    @classmethod
    async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
        """Run one non-streaming agent turn.

        Args:
            req: The chat request (history, optional context path, approved /
                rejected call ids, ``auto_execute`` flag).
            user: The user passed through to ``mcp_client_session``.

        Returns:
            ``{"messages": [...]}`` with every message produced during this
            turn, plus ``"pending_mcp_calls"`` when calls await confirmation.

        Raises:
            HTTPException: 400 when no model is configured or there is nothing
                to confirm; 502 when the completion request fails.
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)
        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingMcpCall] = []

        approved_ids = {i for i in (req.approved_mcp_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_mcp_call_ids or []) if isinstance(i, str) and i.strip()}

        async with mcp_client_session(user, current_path) as mcp_session:
            tools_schema = await _list_mcp_tools(mcp_session)
            tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")}

            if approved_ids or rejected_ids:
                for call_id, mcp_call in cls._calls_to_resolve(internal_messages, approved_ids, rejected_ids):
                    name = str(mcp_call.get("name") or "")
                    arguments = mcp_call.get("arguments") if isinstance(mcp_call.get("arguments"), dict) else {}
                    tool_desc = tool_index.get(name)

                    # A call listed as both approved and rejected is treated as rejected.
                    if call_id in rejected_ids:
                        content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                    elif not tool_desc:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    else:
                        content = await cls._safe_execute(mcp_session, name, arguments)
                    tool_msg = cls._tool_message(call_id, content)
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)

            ability = await _choose_chat_ability()

            # Cap the number of model round-trips so a tool-calling loop cannot run forever.
            for _ in range(8):
                try:
                    assistant = await chat_completion(
                        internal_messages,
                        ability=ability,
                        tools=tools_schema,
                        tool_choice="auto",
                        timeout=60.0,
                    )
                except MissingModelError as exc:
                    raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
                except httpx.HTTPStatusError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
                except httpx.RequestError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc

                assistant = _ensure_mcp_call_ids(
                    assistant if isinstance(assistant, dict) else {"role": "assistant", "content": ""}
                )
                internal_messages.append(assistant)
                new_messages.append(assistant)

                mcp_calls = assistant.get("mcp_calls")
                if not isinstance(mcp_calls, list) or not mcp_calls:
                    break

                pending = []
                for call in mcp_calls:
                    if not isinstance(call, dict):
                        continue
                    call_id = str(call.get("id") or "")
                    name = str(call.get("name") or "")
                    arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
                    tool_desc = tool_index.get(name)

                    if not tool_desc:
                        tool_msg = cls._tool_message(
                            call_id, tool_result_to_content({"error": f"unknown_tool: {name}"})
                        )
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        continue

                    if _tool_requires_confirmation(tool_desc) and not req.auto_execute:
                        pending.append(_extract_pending(call, True))
                        continue

                    tool_msg = cls._tool_message(call_id, await cls._safe_execute(mcp_session, name, arguments))
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)

                # Stop and hand control back to the user while calls await approval.
                if pending:
                    break

        return cls._build_payload(new_messages, pending)

    @classmethod
    async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
        """Run one agent turn as an async generator of server-sent-event frames.

        Emitted events: ``mcp_call_start`` / ``mcp_call_end`` around tool
        executions, ``assistant_start`` / ``assistant_delta`` /
        ``assistant_end`` for each model reply, ``pending`` when calls await
        confirmation, and a final ``done`` carrying the same payload shape as
        ``chat``. Errors are not raised to the caller: they are appended as an
        assistant message and reported through the ``done`` event.

        Args:
            req: The chat request (history, optional context path, approved /
                rejected call ids, ``auto_execute`` flag).
            user: The user passed through to ``mcp_client_session``.

        Yields:
            UTF-8 encoded SSE frames (``bytes``).
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)
        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingMcpCall] = []

        approved_ids = {i for i in (req.approved_mcp_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_mcp_call_ids or []) if isinstance(i, str) and i.strip()}

        try:
            async with mcp_client_session(user, current_path) as mcp_session:
                tools_schema = await _list_mcp_tools(mcp_session)
                tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")}

                if approved_ids or rejected_ids:
                    for call_id, mcp_call in cls._calls_to_resolve(internal_messages, approved_ids, rejected_ids):
                        name = str(mcp_call.get("name") or "")
                        arguments = mcp_call.get("arguments") if isinstance(mcp_call.get("arguments"), dict) else {}
                        tool_desc = tool_index.get(name)

                        # A call listed as both approved and rejected is treated as rejected.
                        if call_id in rejected_ids:
                            content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                        elif not tool_desc:
                            content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        else:
                            # Only calls that actually run get a start event.
                            yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name})
                            content = await cls._safe_execute(mcp_session, name, arguments)
                        tool_msg = cls._tool_message(call_id, content)
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg})

                ability = await _choose_chat_ability()

                # Cap the number of model round-trips so a tool-calling loop cannot run forever.
                for _ in range(8):
                    assistant_event_id = str(uuid.uuid4())
                    yield _sse("assistant_start", {"id": assistant_event_id})

                    assistant_message: Dict[str, Any] | None = None
                    try:
                        async for event in chat_completion_stream(
                            internal_messages,
                            ability=ability,
                            tools=tools_schema,
                            tool_choice="auto",
                            timeout=60.0,
                        ):
                            event_type = event.get("type")
                            if event_type == "delta":
                                delta = event.get("delta")
                                if isinstance(delta, str) and delta:
                                    yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
                            elif event_type == "message":
                                msg = event.get("message")
                                if isinstance(msg, dict):
                                    assistant_message = msg
                    except MissingModelError as exc:
                        raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
                    except httpx.HTTPStatusError as exc:
                        raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
                    except httpx.RequestError as exc:
                        raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc

                    # The stream may end without a final "message" event; fall back to an empty reply.
                    if not assistant_message:
                        assistant_message = {"role": "assistant", "content": ""}

                    assistant_message = _ensure_mcp_call_ids(assistant_message)
                    internal_messages.append(assistant_message)
                    new_messages.append(assistant_message)
                    yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})

                    mcp_calls = assistant_message.get("mcp_calls")
                    if not isinstance(mcp_calls, list) or not mcp_calls:
                        break

                    pending = []
                    for call in mcp_calls:
                        if not isinstance(call, dict):
                            continue
                        call_id = str(call.get("id") or "")
                        name = str(call.get("name") or "")
                        arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
                        tool_desc = tool_index.get(name)

                        if not tool_desc:
                            tool_msg = cls._tool_message(
                                call_id, tool_result_to_content({"error": f"unknown_tool: {name}"})
                            )
                            internal_messages.append(tool_msg)
                            new_messages.append(tool_msg)
                            yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg})
                            continue

                        if _tool_requires_confirmation(tool_desc) and not req.auto_execute:
                            pending.append(_extract_pending(call, True))
                            continue

                        yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name})
                        tool_msg = cls._tool_message(call_id, await cls._safe_execute(mcp_session, name, arguments))
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg})

                    # Stop and hand control back to the user while calls await approval.
                    if pending:
                        yield _sse("pending", {"pending_mcp_calls": [item.model_dump() for item in pending]})
                        break

            yield _sse("done", cls._build_payload(new_messages, pending))

        except asyncio.CancelledError:
            # NOTE(review): cancellation ends the stream silently without re-raising;
            # presumably this covers client disconnects — confirm this is intended.
            return
        except HTTPException as exc:
            detail = exc.detail
            content = detail if isinstance(detail, str) else str(detail)
            if not content.strip():
                content = f"请求失败({exc.status_code})"
            new_messages.append({"role": "assistant", "content": content})
            yield _sse("done", cls._build_payload(new_messages, pending))
        except Exception as exc:  # noqa: BLE001
            # Last-resort boundary: surface the failure to the client as an assistant message.
            new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
            yield _sse("done", cls._build_payload(new_messages, pending))
|