diff --git a/domain/agent/__init__.py b/domain/agent/__init__.py index dbdebaf..7b794cf 100644 --- a/domain/agent/__init__.py +++ b/domain/agent/__init__.py @@ -1,9 +1,10 @@ from .service import AgentService -from .types import AgentChatContext, AgentChatRequest, PendingToolCall +from .types import AgentChatContext, AgentChatRequest, McpCall, PendingMcpCall __all__ = [ "AgentService", "AgentChatContext", "AgentChatRequest", - "PendingToolCall", + "McpCall", + "PendingMcpCall", ] diff --git a/domain/agent/api.py b/domain/agent/api.py index e8ae5cd..d4d8c7c 100644 --- a/domain/agent/api.py +++ b/domain/agent/api.py @@ -14,7 +14,7 @@ router = APIRouter(prefix="/api/agent", tags=["agent"]) @router.post("/chat") -@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"]) +@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute", "approved_mcp_call_ids", "rejected_mcp_call_ids"]) async def chat( request: Request, payload: AgentChatRequest, @@ -25,7 +25,7 @@ async def chat( @router.post("/chat/stream") -@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"]) +@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute", "approved_mcp_call_ids", "rejected_mcp_call_ids"]) async def chat_stream( request: Request, payload: AgentChatRequest, diff --git a/domain/agent/mcp.py b/domain/agent/mcp.py new file mode 100644 index 0000000..6fed6d1 --- /dev/null +++ b/domain/agent/mcp.py @@ -0,0 +1,334 @@ +import inspect +import json +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import Annotated, Any, Literal +from urllib.parse import quote, unquote + +import httpx +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.server.auth.provider import AccessToken +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.server import AuthSettings +from mcp.types import ToolAnnotations +from pydantic import Field + +from domain.auth import AuthService, User +from domain.processors import ProcessorService + +from .tools import get_tool, mcp_tool_descriptors +from .tools.base import McpToolDescriptor, normalize_tool_result, tool_result_to_content + +INTERNAL_MCP_BASE_URL = "http://127.0.0.1:8000/" +CURRENT_PATH_HEADER = "x-foxel-current-path" + + +def _normalize_path(path: str | None) -> str | None: + if not path: + return None + value = str(path).strip().replace("\\", "/") + if not value: + return None + if not value.startswith("/"): + value = "/" + value + return value.rstrip("/") or "/" + + +def _header_current_path(ctx: Context | None) -> str | None: + request = ctx.request_context.request if ctx and ctx.request_context else None + if request is None: + return None + return _normalize_path(request.headers.get(CURRENT_PATH_HEADER)) + + +def _field_annotation(schema: dict[str, Any], required: bool) -> tuple[Any, Any]: + raw_type = schema.get("type") + enum_values = schema.get("enum") + description = str(schema.get("description") or "").strip() or None + default = schema.get("default", inspect.Parameter.empty if required else None) + + annotation: Any + if isinstance(enum_values, list) and enum_values: + annotation = Literal.__getitem__(tuple(enum_values)) + elif raw_type == "string": + annotation = str + elif raw_type == "integer": + annotation = int + elif raw_type == "number": + annotation = float + elif raw_type == "boolean": + annotation = bool + elif raw_type == "array": + annotation = list[Any] + elif raw_type == "object": + annotation = dict[str, Any] + else: + annotation = Any + + if not required and default is None: + annotation = annotation | None + + if description: + annotation = Annotated[annotation, Field(description=description)] + return annotation, default + + +def _build_tool_signature(descriptor: McpToolDescriptor) -> inspect.Signature: + schema = descriptor.input_schema if isinstance(descriptor.input_schema, dict) else {} + properties = schema.get("properties") if isinstance(schema.get("properties"), dict) else {} + required = set(schema.get("required") or []) + parameters: list[inspect.Parameter] = [] + for key, value in properties.items(): + prop_schema = value if isinstance(value, dict) else {} + annotation, default = _field_annotation(prop_schema, key in required) + parameters.append( + inspect.Parameter( + str(key), + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=annotation, + ) + ) + return inspect.Signature(parameters=parameters, return_annotation=dict[str, Any]) + + +def _build_tool_wrapper(descriptor: McpToolDescriptor): + async def wrapper(**kwargs: Any) -> dict[str, Any]: + spec = get_tool(descriptor.name) + if not spec: + return normalize_tool_result({"error": f"unknown_tool: {descriptor.name}"}) + try: + result = await spec.handler(kwargs) + return normalize_tool_result(result) + except Exception as exc: # noqa: BLE001 + return normalize_tool_result({"error": str(exc)}) + + wrapper.__name__ = descriptor.name + wrapper.__doc__ = descriptor.description + wrapper.__signature__ = _build_tool_signature(descriptor) + return wrapper + + +class FoxelMcpTokenVerifier: + async def verify_token(self, token: str) -> AccessToken | None: + try: + user = await AuthService.get_current_active_user(await AuthService.get_current_user(token)) + except Exception: # noqa: BLE001 + return None + return AccessToken(token=token, client_id=user.username, scopes=[]) + + +MCP_SERVER = FastMCP( + name="Foxel MCP", + instructions="Foxel 内置 MCP 服务,提供文件系统、网页抓取、时间与处理器相关能力。", + streamable_http_path="/", + token_verifier=FoxelMcpTokenVerifier(), + auth=AuthSettings( + issuer_url="http://127.0.0.1:8000", + resource_server_url=None, + required_scopes=[], + ), +) + + +for descriptor in mcp_tool_descriptors(): + MCP_SERVER.add_tool( + _build_tool_wrapper(descriptor), + name=descriptor.name, + description=descriptor.description, + annotations=ToolAnnotations.model_validate(descriptor.annotations), + meta=descriptor.meta, + structured_output=False, + ) + + +@MCP_SERVER.resource( + "foxel://context/current-path", + name="current_path", + title="Current Path", + description="返回当前请求上下文里的文件管理目录。", + mime_type="application/json", +) +def current_path_resource() -> dict[str, Any]: + return {"current_path": None} + + +@MCP_SERVER.resource( + "foxel://policy/tool-confirmation", + name="tool_confirmation_policy", + title="Tool Confirmation Policy", + description="返回 Foxel agent 对工具审批的策略。", + mime_type="application/json", +) +def tool_confirmation_policy_resource() -> dict[str, Any]: + return { + "read_tools": [tool.name for tool in mcp_tool_descriptors() if not tool.requires_confirmation], + "write_tools": [tool.name for tool in mcp_tool_descriptors() if tool.requires_confirmation], + "rule": "直接调用 MCP tool 时不额外审批;通过 agent 代表用户执行写操作时需要审批。", + } + + +@MCP_SERVER.resource( + "foxel://processors/index", + name="processors_index", + title="Processors Index", + description="返回当前可用处理器列表。", + mime_type="application/json", +) +def processors_index_resource() -> dict[str, Any]: + return {"processors": ProcessorService.list_processors()} + + +async def _tool_resource(tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]: + spec = get_tool(tool_name) + if not spec: + return normalize_tool_result({"error": f"unknown_tool: {tool_name}"}) + try: + result = await spec.handler(arguments) + return normalize_tool_result(result) + except Exception as exc: # noqa: BLE001 + return normalize_tool_result({"error": str(exc)}) + + +@MCP_SERVER.resource( + "foxel://vfs/stat/{path}", + name="vfs_stat_resource", + title="VFS Stat", + description="读取指定路径的文件或目录元信息;path 需要 URL 编码。", + mime_type="application/json", +) +async def vfs_stat_resource(path: str) -> dict[str, Any]: + return await _tool_resource("vfs_stat", {"path": "/" + unquote(path).lstrip("/")}) + + +@MCP_SERVER.resource( + "foxel://vfs/text/{path}", + name="vfs_text_resource", + title="VFS Text", + description="读取文本文件内容;path 需要 URL 编码。", + mime_type="application/json", +) +async def vfs_text_resource(path: str) -> dict[str, Any]: + return await _tool_resource("vfs_read_text", {"path": "/" + unquote(path).lstrip("/")}) + + +@MCP_SERVER.resource( + "foxel://vfs/dir/{path}", + name="vfs_dir_resource", + title="VFS Directory", + description="列出目录内容;path 需要 URL 编码。", + mime_type="application/json", +) +async def vfs_dir_resource(path: str) -> dict[str, Any]: + return await _tool_resource("vfs_list_dir", {"path": "/" + unquote(path).lstrip("/")}) + + +@MCP_SERVER.resource( + "foxel://vfs/search/{query}", + name="vfs_search_resource", + title="VFS Search", + description="搜索文件;query 需要 URL 编码。", + mime_type="application/json", +) +async def vfs_search_resource(query: str) -> dict[str, Any]: + return await _tool_resource("vfs_search", {"q": unquote(query)}) + + +@MCP_SERVER.prompt(name="browse_path", title="Browse Path", description="生成浏览目录的推荐提示词。") +def browse_path_prompt(path: Annotated[str, Field(description="目标目录路径")]) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请先浏览目录 `{path}`,总结结构与关键文件。必要时调用 vfs_list_dir 与 vfs_stat。"}] + + +@MCP_SERVER.prompt(name="inspect_file", title="Inspect File", description="生成查看文件的推荐提示词。") +def inspect_file_prompt(path: Annotated[str, Field(description="目标文件路径")]) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请检查文件 `{path}` 的内容与用途。必要时调用 vfs_read_text。"}] + + +@MCP_SERVER.prompt(name="search_files", title="Search Files", description="生成搜索文件的推荐提示词。") +def search_files_prompt(query: Annotated[str, Field(description="搜索关键词")]) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请搜索与 `{query}` 相关的文件,并按相关性总结。必要时调用 vfs_search。"}] + + +@MCP_SERVER.prompt(name="edit_file_safely", title="Edit File Safely", description="生成安全修改文件的推荐提示词。") +def edit_file_safely_prompt(path: Annotated[str, Field(description="目标文件路径")]) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请先读取 `{path}`,解释拟修改点,再等待我确认后执行写入。"}] + + +@MCP_SERVER.prompt(name="run_processor", title="Run Processor", description="生成运行处理器的推荐提示词。") +def run_processor_prompt( + path: Annotated[str, Field(description="目标文件或目录路径")], + processor_type: Annotated[str, Field(description="处理器类型")], +) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请检查 `{path}` 是否适合运行处理器 `{processor_type}`,确认参数后再执行 processors_run。"}] + + +@MCP_SERVER.prompt(name="fetch_web_page", title="Fetch Web Page", description="生成抓取网页的推荐提示词。") +def fetch_web_page_prompt(url: Annotated[str, Field(description="目标网址")]) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"请抓取网页 `{url}`,并总结标题、正文与关键链接。必要时调用 web_fetch。"}] + + +MCP_HTTP_APP = MCP_SERVER.streamable_http_app() + + +def loopback_httpx_client_factory(app): + def factory(headers: dict[str, str] | None = None, timeout=None, auth=None) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url=INTERNAL_MCP_BASE_URL.rstrip("/"), + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + + return factory + + +async def create_loopback_mcp_headers(user: User | None, current_path: str | None = None) -> dict[str, str]: + headers: dict[str, str] = {} + if user is not None: + token = await AuthService.create_access_token( + {"sub": user.username}, + expires_delta=timedelta(minutes=5), + ) + headers["Authorization"] = f"Bearer {token}" + if current_path: + headers[CURRENT_PATH_HEADER] = current_path + return headers + + +@asynccontextmanager +async def mcp_client_session(user: User | None, current_path: str | None = None): + headers = await create_loopback_mcp_headers(user, current_path) + async with streamablehttp_client( + INTERNAL_MCP_BASE_URL, + headers=headers, + httpx_client_factory=loopback_httpx_client_factory(MCP_HTTP_APP), + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + yield session + + +def mcp_content_to_text(content: list[Any], structured_content: dict[str, Any] | None = None) -> str: + if structured_content is not None: + try: + return json.dumps(structured_content, ensure_ascii=False) + except TypeError: + pass + + text_parts: list[str] = [] + for item in content: + item_type = getattr(item, "type", None) + if item_type == "text": + text = getattr(item, "text", None) + if isinstance(text, str) and text: + text_parts.append(text) + if text_parts: + return "\n".join(text_parts) + return tool_result_to_content({"error": "empty_mcp_content"}) + + +def encode_resource_path(path: str) -> str: + return quote(path.lstrip("/"), safe="") diff --git a/domain/agent/service.py b/domain/agent/service.py index 31928fd..bcb513f 100644 --- a/domain/agent/service.py +++ b/domain/agent/service.py @@ -8,27 +8,27 @@ from fastapi import HTTPException from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream from domain.auth import User -from .tools import get_tool, openai_tools, tool_result_to_content -from .types import AgentChatRequest, PendingToolCall + +from .mcp import mcp_client_session, mcp_content_to_text +from .tools import tool_result_to_content +from .types import AgentChatRequest, PendingMcpCall -def _normalize_path(p: Optional[str]) -> Optional[str]: - if not p: +def _normalize_path(path: Optional[str]) -> Optional[str]: + if not path: return None - s = str(p).strip() - if not s: + value = str(path).strip().replace("\\", "/") + if not value: return None - s = s.replace("\\", "/") - if not s.startswith("/"): - s = "/" + s - s = s.rstrip("/") or "/" - return s + if not value.startswith("/"): + value = "/" + value + return value.rstrip("/") or "/" def _build_system_prompt(current_path: Optional[str]) -> str: lines = [ "你是 Foxel 的 AI 助手。", - "你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。", + "你可以通过 MCP 工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。", "", "可用工具:", "- time:获取服务器当前时间(精确到秒,英文星期),支持 year/month/day/hour/minute/second 偏移。", @@ -60,13 +60,13 @@ def _build_system_prompt(current_path: Optional[str]) -> str: return "\n".join(lines) -def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]: - tool_calls = message.get("tool_calls") - if not isinstance(tool_calls, list): +def _ensure_mcp_call_ids(message: Dict[str, Any]) -> Dict[str, Any]: + mcp_calls = message.get("mcp_calls") + if not isinstance(mcp_calls, list): return message changed = False - for idx, call in enumerate(tool_calls): + for idx, call in enumerate(mcp_calls): if not isinstance(call, dict): continue call_id = call.get("id") @@ -76,57 +76,54 @@ def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]: changed = True if changed: - message["tool_calls"] = tool_calls + message["mcp_calls"] = mcp_calls return message -def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall: - call_id = str(tool_call.get("id") or "") - fn = tool_call.get("function") or {} - name = str((fn.get("name") if isinstance(fn, dict) else None) or "") - raw_args = fn.get("arguments") if isinstance(fn, dict) else None - arguments: Dict[str, Any] = {} - if isinstance(raw_args, str) and raw_args.strip(): - try: - parsed = json.loads(raw_args) - if isinstance(parsed, dict): - arguments = parsed - except json.JSONDecodeError: - arguments = {} - return PendingToolCall( - id=call_id, - name=name, +def _extract_pending(mcp_call: Dict[str, Any], requires_confirmation: bool) -> PendingMcpCall: + arguments = mcp_call.get("arguments") if isinstance(mcp_call.get("arguments"), dict) else {} + return PendingMcpCall( + id=str(mcp_call.get("id") or ""), + name=str(mcp_call.get("name") or ""), arguments=arguments, requires_confirmation=requires_confirmation, ) -def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]: +def _find_last_assistant_mcp_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]: for idx in range(len(messages) - 1, -1, -1): msg = messages[idx] if not isinstance(msg, dict): continue if msg.get("role") != "assistant": continue - tool_calls = msg.get("tool_calls") - if isinstance(tool_calls, list) and tool_calls: + mcp_calls = msg.get("mcp_calls") + if isinstance(mcp_calls, list) and mcp_calls: return idx, msg raise HTTPException(status_code=400, detail="没有可确认的待执行操作") -def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]: +def _existing_mcp_result_ids(messages: List[Dict[str, Any]]) -> set[str]: ids: set[str] = set() for msg in messages: if not isinstance(msg, dict): continue if msg.get("role") != "tool": continue - tool_call_id = msg.get("tool_call_id") - if isinstance(tool_call_id, str) and tool_call_id.strip(): - ids.add(tool_call_id) + call_id = msg.get("mcp_call_id") + if isinstance(call_id, str) and call_id.strip(): + ids.add(call_id) return ids +def _tool_requires_confirmation(tool_descriptor: Dict[str, Any]) -> bool: + meta = tool_descriptor.get("meta") if isinstance(tool_descriptor.get("meta"), dict) else {} + if "requires_confirmation" in meta: + return bool(meta.get("requires_confirmation")) + annotations = tool_descriptor.get("annotations") if isinstance(tool_descriptor.get("annotations"), dict) else {} + return not bool(annotations.get("readOnlyHint")) + + async def _choose_chat_ability() -> str: tools_model = await AIProviderService.get_default_model("tools") return "tools" if tools_model else "chat" @@ -142,245 +139,91 @@ def _format_exc(exc: BaseException) -> str: return text if text else exc.__class__.__name__ +async def _list_mcp_tools(session) -> List[Dict[str, Any]]: + result = await session.list_tools() + tools: List[Dict[str, Any]] = [] + for item in result.tools: + annotations = getattr(item, "annotations", None) + meta = getattr(item, "meta", None) + tools.append( + { + "name": str(getattr(item, "name", "") or ""), + "description": str(getattr(item, "description", "") or ""), + "input_schema": getattr(item, "inputSchema", None) or {}, + "annotations": annotations.model_dump(exclude_none=True) if annotations is not None else {}, + "meta": meta if isinstance(meta, dict) else {}, + } + ) + return tools + + +async def _execute_mcp_call(session, name: str, arguments: Dict[str, Any]) -> str: + result = await session.call_tool(name, arguments) + return mcp_content_to_text(result.content, result.structuredContent) + + class AgentService: @classmethod async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]: history: List[Dict[str, Any]] = list(req.messages or []) current_path = _normalize_path(req.context.current_path if req.context else None) - system_prompt = _build_system_prompt(current_path) internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history - new_messages: List[Dict[str, Any]] = [] - pending: List[PendingToolCall] = [] + pending: List[PendingMcpCall] = [] - approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()} - rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()} + approved_ids = {i for i in (req.approved_mcp_call_ids or []) if isinstance(i, str) and i.strip()} + rejected_ids = {i for i in (req.rejected_mcp_call_ids or []) if isinstance(i, str) and i.strip()} - if approved_ids or rejected_ids: - _, last_call_msg = _find_last_assistant_tool_calls(internal_messages) - last_call_msg = _ensure_tool_call_ids(last_call_msg) - tool_calls = last_call_msg.get("tool_calls") or [] - call_map: Dict[str, Dict[str, Any]] = { - str(c.get("id")): c - for c in tool_calls - if isinstance(c, dict) and isinstance(c.get("id"), str) - } + async with mcp_client_session(user, current_path) as mcp_session: + tools_schema = await _list_mcp_tools(mcp_session) + tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")} - existing_ids = _existing_tool_result_ids(internal_messages) - for call_id in approved_ids | rejected_ids: - if call_id in existing_ids: - continue - tool_call = call_map.get(call_id) - if not tool_call: - continue - fn = tool_call.get("function") or {} - name = fn.get("name") if isinstance(fn, dict) else None - args_raw = fn.get("arguments") if isinstance(fn, dict) else None - args: Dict[str, Any] = {} - if isinstance(args_raw, str) and args_raw.strip(): - try: - parsed = json.loads(args_raw) - if isinstance(parsed, dict): - args = parsed - except json.JSONDecodeError: - args = {} - - spec = get_tool(str(name or "")) - if call_id in rejected_ids: - content = tool_result_to_content({"canceled": True, "reason": "user_rejected"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - continue - - if not spec: - content = tool_result_to_content({"error": f"unknown_tool: {name}"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - continue - - try: - result = await spec.handler(args) - content = tool_result_to_content(result) - except Exception as exc: # noqa: BLE001 - content = tool_result_to_content({"error": str(exc)}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - - tools_schema = openai_tools() - ability = await _choose_chat_ability() - max_loops = 4 - - for _ in range(max_loops): - try: - assistant = await chat_completion( - internal_messages, - ability=ability, - tools=tools_schema, - tool_choice="auto", - timeout=60.0, - ) - except MissingModelError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - except httpx.HTTPStatusError as exc: - raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc - except httpx.RequestError as exc: - raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc - - assistant = _ensure_tool_call_ids(assistant) - internal_messages.append(assistant) - new_messages.append(assistant) - - tool_calls = assistant.get("tool_calls") - if not isinstance(tool_calls, list) or not tool_calls: - break - - pending = [] - for call in tool_calls: - if not isinstance(call, dict): - continue - call_id = str(call.get("id") or "") - fn = call.get("function") or {} - name = fn.get("name") if isinstance(fn, dict) else None - args_raw = fn.get("arguments") if isinstance(fn, dict) else None - args: Dict[str, Any] = {} - if isinstance(args_raw, str) and args_raw.strip(): - try: - parsed = json.loads(args_raw) - if isinstance(parsed, dict): - args = parsed - except json.JSONDecodeError: - args = {} - - spec = get_tool(str(name or "")) - if not spec: - content = tool_result_to_content({"error": f"unknown_tool: {name}"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - continue - - if spec.requires_confirmation and not req.auto_execute: - pending.append(_extract_pending(call, True)) - continue - - try: - result = await spec.handler(args) - content = tool_result_to_content(result) - except Exception as exc: # noqa: BLE001 - content = tool_result_to_content({"error": str(exc)}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - - if pending: - break - - payload: Dict[str, Any] = {"messages": new_messages} - if pending: - payload["pending_tool_calls"] = [p.model_dump() for p in pending] - return payload - - @classmethod - async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]): - history: List[Dict[str, Any]] = list(req.messages or []) - current_path = _normalize_path(req.context.current_path if req.context else None) - - system_prompt = _build_system_prompt(current_path) - internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history - - new_messages: List[Dict[str, Any]] = [] - pending: List[PendingToolCall] = [] - - approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()} - rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()} - - try: if approved_ids or rejected_ids: - _, last_call_msg = _find_last_assistant_tool_calls(internal_messages) - last_call_msg = _ensure_tool_call_ids(last_call_msg) - tool_calls = last_call_msg.get("tool_calls") or [] + _, last_call_msg = _find_last_assistant_mcp_calls(internal_messages) + last_call_msg = _ensure_mcp_call_ids(last_call_msg) + mcp_calls = last_call_msg.get("mcp_calls") or [] call_map: Dict[str, Dict[str, Any]] = { - str(c.get("id")): c - for c in tool_calls - if isinstance(c, dict) and isinstance(c.get("id"), str) + str(call.get("id")): call + for call in mcp_calls + if isinstance(call, dict) and isinstance(call.get("id"), str) } - existing_ids = _existing_tool_result_ids(internal_messages) + existing_ids = _existing_mcp_result_ids(internal_messages) for call_id in approved_ids | rejected_ids: if call_id in existing_ids: continue - tool_call = call_map.get(call_id) - if not tool_call: + mcp_call = call_map.get(call_id) + if not mcp_call: continue - fn = tool_call.get("function") or {} - name = fn.get("name") if isinstance(fn, dict) else None - args_raw = fn.get("arguments") if isinstance(fn, dict) else None - args: Dict[str, Any] = {} - if isinstance(args_raw, str) and args_raw.strip(): - try: - parsed = json.loads(args_raw) - if isinstance(parsed, dict): - args = parsed - except json.JSONDecodeError: - args = {} + name = str(mcp_call.get("name") or "") + arguments = mcp_call.get("arguments") if isinstance(mcp_call.get("arguments"), dict) else {} + tool_desc = tool_index.get(name) - spec = get_tool(str(name or "")) if call_id in rejected_ids: content = tool_result_to_content({"canceled": True, "reason": "user_rejected"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) - continue - - if not spec: + elif not tool_desc: content = tool_result_to_content({"error": f"unknown_tool: {name}"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} - internal_messages.append(tool_msg) - new_messages.append(tool_msg) - yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) - continue - - yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name}) - try: - result = await spec.handler(args) - content = tool_result_to_content(result) - except Exception as exc: # noqa: BLE001 - content = tool_result_to_content({"error": str(exc)}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + else: + try: + content = await _execute_mcp_call(mcp_session, name, arguments) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} internal_messages.append(tool_msg) new_messages.append(tool_msg) - yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg}) - tools_schema = openai_tools() ability = await _choose_chat_ability() - max_loops = 4 - for _ in range(max_loops): - assistant_event_id = uuid.uuid4().hex - yield _sse("assistant_start", {"id": assistant_event_id}) - - assistant_message: Dict[str, Any] | None = None + for _ in range(8): try: - async for event in chat_completion_stream( + assistant = await chat_completion( internal_messages, ability=ability, tools=tools_schema, tool_choice="auto", timeout=60.0, - ): - if event.get("type") == "delta": - delta = event.get("delta") - if isinstance(delta, str) and delta: - yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta}) - elif event.get("type") == "message": - msg = event.get("message") - if isinstance(msg, dict): - assistant_message = msg + ) except MissingModelError as exc: raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc except httpx.HTTPStatusError as exc: @@ -388,66 +231,196 @@ class AgentService: except httpx.RequestError as exc: raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc - if not assistant_message: - assistant_message = {"role": "assistant", "content": ""} + assistant = _ensure_mcp_call_ids(assistant if isinstance(assistant, dict) else {"role": "assistant", "content": ""}) + internal_messages.append(assistant) + new_messages.append(assistant) - assistant_message = _ensure_tool_call_ids(assistant_message) - internal_messages.append(assistant_message) - new_messages.append(assistant_message) - yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message}) - - tool_calls = assistant_message.get("tool_calls") - if not isinstance(tool_calls, list) or not tool_calls: + mcp_calls = assistant.get("mcp_calls") + if not isinstance(mcp_calls, list) or not mcp_calls: break pending = [] - for call in tool_calls: + for call in mcp_calls: if not isinstance(call, dict): continue call_id = str(call.get("id") or "") - fn = call.get("function") or {} - name = fn.get("name") if isinstance(fn, dict) else None - args_raw = fn.get("arguments") if isinstance(fn, dict) else None - args: Dict[str, Any] = {} - if isinstance(args_raw, str) and args_raw.strip(): - try: - parsed = json.loads(args_raw) - if isinstance(parsed, dict): - args = parsed - except json.JSONDecodeError: - args = {} + name = str(call.get("name") or "") + arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {} + tool_desc = tool_index.get(name) - spec = get_tool(str(name or "")) - if not spec: + if not tool_desc: content = tool_result_to_content({"error": f"unknown_tool: {name}"}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} internal_messages.append(tool_msg) new_messages.append(tool_msg) - yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) continue - if spec.requires_confirmation and not req.auto_execute: + if _tool_requires_confirmation(tool_desc) and not req.auto_execute: pending.append(_extract_pending(call, True)) continue - yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name}) try: - result = await spec.handler(args) - content = tool_result_to_content(result) + content = await _execute_mcp_call(mcp_session, name, arguments) except Exception as exc: # noqa: BLE001 content = tool_result_to_content({"error": str(exc)}) - tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} internal_messages.append(tool_msg) new_messages.append(tool_msg) - yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg}) if pending: - yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]}) break + payload: Dict[str, Any] = {"messages": new_messages} + if pending: + payload["pending_mcp_calls"] = [item.model_dump() for item in pending] + return payload + + @classmethod + async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]): + history: List[Dict[str, Any]] = list(req.messages or []) + current_path = _normalize_path(req.context.current_path if req.context else None) + system_prompt = _build_system_prompt(current_path) + internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history + new_messages: List[Dict[str, Any]] = [] + pending: List[PendingMcpCall] = [] + + approved_ids = {i for i in (req.approved_mcp_call_ids or []) if isinstance(i, str) and i.strip()} + rejected_ids = {i for i in (req.rejected_mcp_call_ids or []) if isinstance(i, str) and i.strip()} + + try: + async with mcp_client_session(user, current_path) as mcp_session: + tools_schema = await _list_mcp_tools(mcp_session) + tool_index = {tool["name"]: tool for tool in tools_schema if tool.get("name")} + + if approved_ids or rejected_ids: + _, last_call_msg = _find_last_assistant_mcp_calls(internal_messages) + last_call_msg = _ensure_mcp_call_ids(last_call_msg) + mcp_calls = last_call_msg.get("mcp_calls") or [] + call_map: Dict[str, Dict[str, Any]] = { + str(call.get("id")): call + for call in mcp_calls + if isinstance(call, dict) and isinstance(call.get("id"), str) + } + + existing_ids = _existing_mcp_result_ids(internal_messages) + for call_id in approved_ids | rejected_ids: + if call_id in existing_ids: + continue + mcp_call = call_map.get(call_id) + if not mcp_call: + continue + + name = str(mcp_call.get("name") or "") + arguments = mcp_call.get("arguments") if isinstance(mcp_call.get("arguments"), dict) else {} + tool_desc = tool_index.get(name) + + if call_id in rejected_ids: + content = tool_result_to_content({"canceled": True, "reason": "user_rejected"}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg}) + continue + + if not tool_desc: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg}) + continue + + yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name}) + try: + content = await _execute_mcp_call(mcp_session, name, arguments) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg}) + + ability = await _choose_chat_ability() + + for _ in range(8): + assistant_event_id = str(uuid.uuid4()) + yield _sse("assistant_start", {"id": assistant_event_id}) + + assistant_message: Dict[str, Any] | None = None + try: + async for event in chat_completion_stream( + internal_messages, + ability=ability, + tools=tools_schema, + tool_choice="auto", + timeout=60.0, + ): + event_type = event.get("type") + if event_type == "delta": + delta = event.get("delta") + if isinstance(delta, str) and delta: + yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta}) + elif event_type == "message": + msg = event.get("message") + if isinstance(msg, dict): + assistant_message = msg + except MissingModelError as exc: + raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc + except httpx.HTTPStatusError as exc: + raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc + except httpx.RequestError as exc: + raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc + + if not assistant_message: + assistant_message = {"role": "assistant", "content": ""} + + assistant_message = _ensure_mcp_call_ids(assistant_message) + internal_messages.append(assistant_message) + new_messages.append(assistant_message) + yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message}) + + mcp_calls = assistant_message.get("mcp_calls") + if not isinstance(mcp_calls, list) or not mcp_calls: + break + + pending = [] + for call in mcp_calls: + if not isinstance(call, dict): + continue + call_id = str(call.get("id") or "") + name = str(call.get("name") or "") + arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {} + tool_desc = tool_index.get(name) + + if not tool_desc: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg}) + continue + + if _tool_requires_confirmation(tool_desc) and not req.auto_execute: + pending.append(_extract_pending(call, True)) + continue + + yield _sse("mcp_call_start", {"mcp_call_id": call_id, "name": name}) + try: + content = await _execute_mcp_call(mcp_session, name, arguments) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "mcp_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("mcp_call_end", {"mcp_call_id": call_id, "name": name, "message": tool_msg}) + + if pending: + yield _sse("pending", {"pending_mcp_calls": [item.model_dump() for item in pending]}) + break + payload: Dict[str, Any] = {"messages": new_messages} if pending: - payload["pending_tool_calls"] = [p.model_dump() for p in pending] + payload["pending_mcp_calls"] = [item.model_dump() for item in pending] yield _sse("done", payload) except asyncio.CancelledError: @@ -460,13 +433,11 @@ class AgentService: new_messages.append({"role": "assistant", "content": content}) payload: Dict[str, Any] = {"messages": new_messages} if pending: - payload["pending_tool_calls"] = [p.model_dump() for p in pending] + payload["pending_mcp_calls"] = [item.model_dump() for item in pending] yield _sse("done", payload) - return except Exception as exc: # noqa: BLE001 new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"}) payload: Dict[str, Any] = {"messages": new_messages} if pending: - payload["pending_tool_calls"] = [p.model_dump() for p in pending] + payload["pending_mcp_calls"] = [item.model_dump() for item in pending] yield _sse("done", payload) - return diff --git a/domain/agent/tools/__init__.py b/domain/agent/tools/__init__.py index 71537d6..ad6b362 100644 --- a/domain/agent/tools/__init__.py +++ b/domain/agent/tools/__init__.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from .base import ToolSpec, tool_result_to_content +from .base import McpToolDescriptor, ToolSpec, tool_result_to_content, tool_spec_to_mcp_descriptor from .processors import TOOLS as PROCESSOR_TOOLS from .time import TOOLS as TIME_TOOLS from .vfs import TOOLS as VFS_TOOLS @@ -15,23 +15,19 @@ def get_tool(name: str) -> Optional[ToolSpec]: return TOOLS.get(name) -def openai_tools() -> List[Dict[str, Any]]: - out: List[Dict[str, Any]] = [] - for spec in TOOLS.values(): - out.append({ - "type": "function", - "function": { - "name": spec.name, - "description": spec.description, - "parameters": spec.parameters, - }, - }) - return out +def list_tool_specs() -> List[ToolSpec]: + return list(TOOLS.values()) + + +def mcp_tool_descriptors() -> List[McpToolDescriptor]: + return [tool_spec_to_mcp_descriptor(spec) for spec in TOOLS.values()] __all__ = [ + "McpToolDescriptor", "ToolSpec", "get_tool", - "openai_tools", + "list_tool_specs", + "mcp_tool_descriptors", "tool_result_to_content", ] diff --git a/domain/agent/tools/base.py b/domain/agent/tools/base.py index 402c628..9ab550d 100644 --- a/domain/agent/tools/base.py +++ b/domain/agent/tools/base.py @@ -3,6 +3,16 @@ from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, List, Optional +@dataclass(frozen=True) +class McpToolDescriptor: + name: str + description: str + input_schema: Dict[str, Any] + annotations: Dict[str, Any] + meta: Dict[str, Any] + requires_confirmation: bool + + @dataclass(frozen=True) class ToolSpec: name: str @@ -141,9 +151,31 @@ def _normalize_tool_result(result: Any) -> Dict[str, Any]: return {"ok": True, "summary": summary, "view": view, "data": result} +def normalize_tool_result(result: Any) -> Dict[str, Any]: + return _normalize_tool_result(result) + + def tool_result_to_content(result: Any) -> str: - payload = _normalize_tool_result(result) + payload = normalize_tool_result(result) try: return json.dumps(payload, ensure_ascii=False, default=str) except TypeError: return json.dumps({"ok": False, "summary": "error", "view": {"type": "error", "message": "error"}}, ensure_ascii=False) + + +def tool_spec_to_mcp_descriptor(spec: ToolSpec) -> McpToolDescriptor: + read_only = not spec.requires_confirmation + annotations: Dict[str, Any] = { + "readOnlyHint": read_only, + "destructiveHint": bool(spec.requires_confirmation), + } + if spec.name == "web_fetch": + annotations["openWorldHint"] = True + return McpToolDescriptor( + name=spec.name, + description=spec.description, + input_schema=spec.parameters, + annotations=annotations, + meta={"requires_confirmation": spec.requires_confirmation}, + requires_confirmation=spec.requires_confirmation, + ) diff --git a/domain/agent/types.py b/domain/agent/types.py index 7513ba2..0e695c7 100644 --- a/domain/agent/types.py +++ b/domain/agent/types.py @@ -10,14 +10,19 @@ class AgentChatContext(BaseModel): class AgentChatRequest(BaseModel): messages: List[Dict[str, Any]] = Field(default_factory=list) auto_execute: bool = False - approved_tool_call_ids: List[str] = Field(default_factory=list) - rejected_tool_call_ids: List[str] = Field(default_factory=list) + approved_mcp_call_ids: List[str] = Field(default_factory=list) + rejected_mcp_call_ids: List[str] = Field(default_factory=list) context: Optional[AgentChatContext] = None -class PendingToolCall(BaseModel): +class McpCall(BaseModel): + id: str + name: str + arguments: Dict[str, Any] = Field(default_factory=dict) + + +class PendingMcpCall(BaseModel): id: str name: str arguments: Dict[str, Any] = Field(default_factory=dict) requires_confirmation: bool = True - diff --git a/domain/ai/inference.py b/domain/ai/inference.py index 6e74981..2f11a50 100644 --- a/domain/ai/inference.py +++ b/domain/ai/inference.py @@ -15,6 +15,102 @@ class MissingModelError(RuntimeError): pass +def _mcp_tools_to_openai_wire(tools: List[Dict[str, Any]] | None) -> List[Dict[str, Any]] | None: + if not tools: + return None + out: List[Dict[str, Any]] = [] + for tool in tools: + if not isinstance(tool, dict): + continue + name = tool.get("name") + if not isinstance(name, str) or not name.strip(): + continue + out.append( + { + "type": "function", + "function": { + "name": name, + "description": str(tool.get("description") or ""), + "parameters": tool.get("input_schema") if isinstance(tool.get("input_schema"), dict) else {}, + }, + } + ) + return out + + +def _mcp_messages_to_openai_wire(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + out: List[Dict[str, Any]] = [] + for message in messages: + if not isinstance(message, dict): + continue + item = dict(message) + mcp_call_id = item.pop("mcp_call_id", None) + if isinstance(mcp_call_id, str) and mcp_call_id.strip(): + item["tool_call_id"] = mcp_call_id + + mcp_calls = item.pop("mcp_calls", None) + if isinstance(mcp_calls, list): + tool_calls: List[Dict[str, Any]] = [] + for idx, call in enumerate(mcp_calls): + if not isinstance(call, dict): + continue + name = call.get("name") + if not isinstance(name, str) or not name.strip(): + continue + arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {} + tool_calls.append( + { + "id": str(call.get("id") or f"call_{idx}"), + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(arguments, ensure_ascii=False), + }, + } + ) + if tool_calls: + item["tool_calls"] = tool_calls + out.append(item) + return out + + +def _openai_wire_message_to_mcp(message: Dict[str, Any]) -> Dict[str, Any]: + out = dict(message) + tool_call_id = out.pop("tool_call_id", None) + if isinstance(tool_call_id, str) and tool_call_id.strip(): + out["mcp_call_id"] = tool_call_id + + tool_calls = out.pop("tool_calls", None) + if isinstance(tool_calls, list): + mcp_calls: List[Dict[str, Any]] = [] + for idx, call in enumerate(tool_calls): + if not isinstance(call, dict): + continue + fn = call.get("function") if isinstance(call.get("function"), dict) else {} + name = fn.get("name") + if not isinstance(name, str) or not name.strip(): + continue + arguments: Dict[str, Any] = {} + raw_args = fn.get("arguments") + if isinstance(raw_args, str) and raw_args.strip(): + try: + parsed = json.loads(raw_args) + if isinstance(parsed, dict): + arguments = parsed + except json.JSONDecodeError: + arguments = {} + mcp_calls.append( + { + "id": str(call.get("id") or f"call_{idx}"), + "name": name, + "arguments": arguments, + } + ) + if mcp_calls: + out["mcp_calls"] = mcp_calls + return out + + async def describe_image_base64(base64_image: str, detail: str = "high") -> str: """ 传入 base64 图片并返回描述文本。缺省时返回错误提示。 @@ -939,34 +1035,39 @@ async def chat_completion( ) -> Dict[str, Any]: model, provider = await _require_model(ability) fmt = str(provider.api_format or "").lower() + wire_messages = _mcp_messages_to_openai_wire(messages) + wire_tools = _mcp_tools_to_openai_wire(tools) if fmt == "openai": - return await _chat_with_openai( + result = await _chat_with_openai( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, tool_choice=tool_choice, temperature=temperature, timeout=timeout, ) + return _openai_wire_message_to_mcp(result) if fmt == "anthropic": - return await _chat_with_anthropic( + result = await _chat_with_anthropic( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, temperature=temperature, timeout=timeout, ) + return _openai_wire_message_to_mcp(result) if fmt == "ollama": - return await _chat_with_ollama( + result = await _chat_with_ollama( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, temperature=temperature, timeout=timeout, ) + return _openai_wire_message_to_mcp(result) raise MissingModelError(f"当前不支持该对话模型接口类型: {provider.api_format}") @@ -1016,38 +1117,49 @@ async def chat_completion_stream( ) -> AsyncIterator[Dict[str, Any]]: model, provider = await _require_model(ability) fmt = str(provider.api_format or "").lower() + wire_messages = _mcp_messages_to_openai_wire(messages) + wire_tools = _mcp_tools_to_openai_wire(tools) if fmt == "openai": async for event in _chat_stream_with_openai( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, tool_choice=tool_choice, temperature=temperature, timeout=timeout, ): + if event.get("type") == "message" and isinstance(event.get("message"), dict): + yield {**event, "message": _openai_wire_message_to_mcp(event["message"])} + continue yield event return if fmt == "anthropic": async for event in _chat_stream_with_anthropic( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, temperature=temperature, timeout=timeout, ): + if event.get("type") == "message" and isinstance(event.get("message"), dict): + yield {**event, "message": _openai_wire_message_to_mcp(event["message"])} + continue yield event return if fmt == "ollama": async for event in _chat_stream_with_ollama( provider, model, - messages, - tools=tools, + wire_messages, + tools=wire_tools, temperature=temperature, timeout=timeout, ): + if event.get("type") == "message" and isinstance(event.get("message"), dict): + yield {**event, "message": _openai_wire_message_to_mcp(event["message"])} + continue yield event return raise MissingModelError(f"当前不支持该对话模型接口类型: {provider.api_format}") diff --git a/main.py b/main.py index f57f9be..1b84210 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from pathlib import Path from contextlib import asynccontextmanager from domain.adapters import runtime_registry +from domain.agent.mcp import MCP_HTTP_APP from domain.config import ConfigService, VERSION from db.session import close_db, init_db from api.routers import include_routers @@ -80,12 +81,13 @@ async def lifespan(app: FastAPI): # 在所有路由加载完成后,挂载静态文件服务(放在最后以避免覆盖 API 路由) app.mount("/", SPAStaticFiles(directory="web/dist", html=True, check_dir=False), name="static") - try: - yield - finally: - await task_scheduler.stop() - await task_queue_service.stop_worker() - await close_db() + async with MCP_HTTP_APP.router.lifespan_context(MCP_HTTP_APP): + try: + yield + finally: + await task_scheduler.stop() + await task_queue_service.stop_worker() + await close_db() def create_app() -> FastAPI: @@ -95,6 +97,7 @@ def create_app() -> FastAPI: lifespan=lifespan, ) include_routers(app) + app.mount("/api/mcp", MCP_HTTP_APP, name="mcp") app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(httpx.HTTPStatusError, httpx_exception_handler) diff --git a/pyproject.toml b/pyproject.toml index 188376e..e4279b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "bcrypt>=5.0.0", "croniter>=6.0.0", "fastapi>=0.127.0", + "mcp>=1.26.0", "paramiko>=4.0.0", "pillow>=12.0.0", "pydantic[email]>=2.12.5", diff --git a/uv.lock b/uv.lock index ce5a2e1..bbae255 100644 --- a/uv.lock +++ b/uv.lock @@ -445,6 +445,7 @@ dependencies = [ { name = "bcrypt" }, { name = "croniter" }, { name = "fastapi" }, + { name = "mcp" }, { name = "paramiko" }, { name = "pillow" }, { name = "pydantic", extra = ["email"] }, @@ -466,6 +467,7 @@ requires-dist = [ { name = "bcrypt", specifier = ">=5.0.0" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "fastapi", specifier = ">=0.127.0" }, + { name = "mcp", specifier = ">=1.26.0" }, { name = "paramiko", specifier = ">=4.0.0" }, { name = "pillow", specifier = ">=12.0.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.12.5" }, @@ -607,6 +609,15 @@ http2 = [ { name = "h2" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "hyperframe" version = "6.1.0" @@ -652,6 +663,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, ] +[[package]] +name = "jsonschema" +version = "4.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + +[[package]] +name = "mcp" +version = "1.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/6d/62e76bbb8144d6ed86e202b5edd8a4cb631e7c8130f3f4893c3f90262b10/mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66", size = 608005, upload-time = "2026-01-24T19:40:32.468Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, +] + [[package]] name = "milvus-lite" version = "2.5.1" @@ -989,6 +1052,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/fffca34caecc4a3f97bda81b2098da5e8ab7efc9a66e819074a11955d87e/pydantic_settings-2.13.1.tar.gz", hash = "sha256:b4c11847b15237fb0171e1462bf540e294affb9b86db4d9aa5c01730bdbe4025", size = 223826, upload-time = "2026-02-19T13:45:08.055Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, +] + [[package]] name = "pyjwt" version = "2.11.0" @@ -998,6 +1075,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, ] +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymilvus" version = "2.6.8" @@ -1141,6 +1223,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/13/8ce16f808297e16968269de44a14f4fef19b64d9766be1d6ba5ba78b579d/qdrant_client-1.16.2-py3-none-any.whl", hash = "sha256:442c7ef32ae0f005e88b5d3c0783c63d4912b97ae756eb5e052523be682f17d3", size = 377186, upload-time = "2025-12-12T10:58:29.282Z" }, ] +[[package]] +name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + +[[package]] +name = "rpds-py" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/81/dad16382ebbd3d0e0328776d8fd7ca94220e4fa0798d1dc5e7da48cb3201/rpds_py-0.30.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:68f19c879420aa08f61203801423f6cd5ac5f0ac4ac82a2368a9fcd6a9a075e0", size = 362099, upload-time = "2025-11-30T20:23:27.316Z" }, + { url = "https://files.pythonhosted.org/packages/2b/60/19f7884db5d5603edf3c6bce35408f45ad3e97e10007df0e17dd57af18f8/rpds_py-0.30.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ec7c4490c672c1a0389d319b3a9cfcd098dcdc4783991553c332a15acf7249be", size = 353192, upload-time = "2025-11-30T20:23:29.151Z" }, + { url = "https://files.pythonhosted.org/packages/bf/c4/76eb0e1e72d1a9c4703c69607cec123c29028bff28ce41588792417098ac/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f251c812357a3fed308d684a5079ddfb9d933860fc6de89f2b7ab00da481e65f", size = 384080, upload-time = "2025-11-30T20:23:30.785Z" }, + { url = "https://files.pythonhosted.org/packages/72/87/87ea665e92f3298d1b26d78814721dc39ed8d2c74b86e83348d6b48a6f31/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac98b175585ecf4c0348fd7b29c3864bda53b805c773cbf7bfdaffc8070c976f", size = 394841, upload-time = "2025-11-30T20:23:32.209Z" }, + { url = "https://files.pythonhosted.org/packages/77/ad/7783a89ca0587c15dcbf139b4a8364a872a25f861bdb88ed99f9b0dec985/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3e62880792319dbeb7eb866547f2e35973289e7d5696c6e295476448f5b63c87", size = 516670, upload-time = "2025-11-30T20:23:33.742Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3c/2882bdac942bd2172f3da574eab16f309ae10a3925644e969536553cb4ee/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e7fc54e0900ab35d041b0601431b0a0eb495f0851a0639b6ef90f7741b39a18", size = 408005, upload-time = "2025-11-30T20:23:35.253Z" }, + { url = "https://files.pythonhosted.org/packages/ce/81/9a91c0111ce1758c92516a3e44776920b579d9a7c09b2b06b642d4de3f0f/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47e77dc9822d3ad616c3d5759ea5631a75e5809d5a28707744ef79d7a1bcfcad", size = 382112, upload-time = "2025-11-30T20:23:36.842Z" }, + { url = "https://files.pythonhosted.org/packages/cf/8e/1da49d4a107027e5fbc64daeab96a0706361a2918da10cb41769244b805d/rpds_py-0.30.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:b4dc1a6ff022ff85ecafef7979a2c6eb423430e05f1165d6688234e62ba99a07", size = 399049, upload-time = "2025-11-30T20:23:38.343Z" }, + { url = "https://files.pythonhosted.org/packages/df/5a/7ee239b1aa48a127570ec03becbb29c9d5a9eb092febbd1699d567cae859/rpds_py-0.30.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4559c972db3a360808309e06a74628b95eaccbf961c335c8fe0d590cf587456f", size = 415661, upload-time = "2025-11-30T20:23:40.263Z" }, + { url = "https://files.pythonhosted.org/packages/70/ea/caa143cf6b772f823bc7929a45da1fa83569ee49b11d18d0ada7f5ee6fd6/rpds_py-0.30.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0ed177ed9bded28f8deb6ab40c183cd1192aa0de40c12f38be4d59cd33cb5c65", size = 565606, upload-time = "2025-11-30T20:23:42.186Z" }, + { url = "https://files.pythonhosted.org/packages/64/91/ac20ba2d69303f961ad8cf55bf7dbdb4763f627291ba3d0d7d67333cced9/rpds_py-0.30.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ad1fa8db769b76ea911cb4e10f049d80bf518c104f15b3edb2371cc65375c46f", size = 591126, upload-time = "2025-11-30T20:23:44.086Z" }, + { url = "https://files.pythonhosted.org/packages/21/20/7ff5f3c8b00c8a95f75985128c26ba44503fb35b8e0259d812766ea966c7/rpds_py-0.30.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:46e83c697b1f1c72b50e5ee5adb4353eef7406fb3f2043d64c33f20ad1c2fc53", size = 553371, upload-time = "2025-11-30T20:23:46.004Z" }, + { url = "https://files.pythonhosted.org/packages/72/c7/81dadd7b27c8ee391c132a6b192111ca58d866577ce2d9b0ca157552cce0/rpds_py-0.30.0-cp314-cp314-win32.whl", hash = "sha256:ee454b2a007d57363c2dfd5b6ca4a5d7e2c518938f8ed3b706e37e5d470801ed", size = 215298, upload-time = "2025-11-30T20:23:47.696Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d2/1aaac33287e8cfb07aab2e6b8ac1deca62f6f65411344f1433c55e6f3eb8/rpds_py-0.30.0-cp314-cp314-win_amd64.whl", hash = "sha256:95f0802447ac2d10bcc69f6dc28fe95fdf17940367b21d34e34c737870758950", size = 228604, upload-time = "2025-11-30T20:23:49.501Z" }, + { url = "https://files.pythonhosted.org/packages/e8/95/ab005315818cc519ad074cb7784dae60d939163108bd2b394e60dc7b5461/rpds_py-0.30.0-cp314-cp314-win_arm64.whl", hash = "sha256:613aa4771c99f03346e54c3f038e4cc574ac09a3ddfb0e8878487335e96dead6", size = 222391, upload-time = "2025-11-30T20:23:50.96Z" }, + { url = "https://files.pythonhosted.org/packages/9e/68/154fe0194d83b973cdedcdcc88947a2752411165930182ae41d983dcefa6/rpds_py-0.30.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:7e6ecfcb62edfd632e56983964e6884851786443739dbfe3582947e87274f7cb", size = 364868, upload-time = "2025-11-30T20:23:52.494Z" }, + { url = "https://files.pythonhosted.org/packages/83/69/8bbc8b07ec854d92a8b75668c24d2abcb1719ebf890f5604c61c9369a16f/rpds_py-0.30.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a1d0bc22a7cdc173fedebb73ef81e07faef93692b8c1ad3733b67e31e1b6e1b8", size = 353747, upload-time = "2025-11-30T20:23:54.036Z" }, + { url = "https://files.pythonhosted.org/packages/ab/00/ba2e50183dbd9abcce9497fa5149c62b4ff3e22d338a30d690f9af970561/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d08f00679177226c4cb8c5265012eea897c8ca3b93f429e546600c971bcbae7", size = 383795, upload-time = "2025-11-30T20:23:55.556Z" }, + { url = "https://files.pythonhosted.org/packages/05/6f/86f0272b84926bcb0e4c972262f54223e8ecc556b3224d281e6598fc9268/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5965af57d5848192c13534f90f9dd16464f3c37aaf166cc1da1cae1fd5a34898", size = 393330, upload-time = "2025-11-30T20:23:57.033Z" }, + { url = "https://files.pythonhosted.org/packages/cb/e9/0e02bb2e6dc63d212641da45df2b0bf29699d01715913e0d0f017ee29438/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a4e86e34e9ab6b667c27f3211ca48f73dba7cd3d90f8d5b11be56e5dbc3fb4e", size = 518194, upload-time = "2025-11-30T20:23:58.637Z" }, + { url = "https://files.pythonhosted.org/packages/ee/ca/be7bca14cf21513bdf9c0606aba17d1f389ea2b6987035eb4f62bd923f25/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d3e6b26f2c785d65cc25ef1e5267ccbe1b069c5c21b8cc724efee290554419", size = 408340, upload-time = "2025-11-30T20:24:00.2Z" }, + { url = "https://files.pythonhosted.org/packages/c2/c7/736e00ebf39ed81d75544c0da6ef7b0998f8201b369acf842f9a90dc8fce/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:626a7433c34566535b6e56a1b39a7b17ba961e97ce3b80ec62e6f1312c025551", size = 383765, upload-time = "2025-11-30T20:24:01.759Z" }, + { url = "https://files.pythonhosted.org/packages/4a/3f/da50dfde9956aaf365c4adc9533b100008ed31aea635f2b8d7b627e25b49/rpds_py-0.30.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:acd7eb3f4471577b9b5a41baf02a978e8bdeb08b4b355273994f8b87032000a8", size = 396834, upload-time = "2025-11-30T20:24:03.687Z" }, + { url = "https://files.pythonhosted.org/packages/4e/00/34bcc2565b6020eab2623349efbdec810676ad571995911f1abdae62a3a0/rpds_py-0.30.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fe5fa731a1fa8a0a56b0977413f8cacac1768dad38d16b3a296712709476fbd5", size = 415470, upload-time = "2025-11-30T20:24:05.232Z" }, + { url = "https://files.pythonhosted.org/packages/8c/28/882e72b5b3e6f718d5453bd4d0d9cf8df36fddeb4ddbbab17869d5868616/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:74a3243a411126362712ee1524dfc90c650a503502f135d54d1b352bd01f2404", size = 565630, upload-time = "2025-11-30T20:24:06.878Z" }, + { url = "https://files.pythonhosted.org/packages/3b/97/04a65539c17692de5b85c6e293520fd01317fd878ea1995f0367d4532fb1/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:3e8eeb0544f2eb0d2581774be4c3410356eba189529a6b3e36bbbf9696175856", size = 591148, upload-time = "2025-11-30T20:24:08.445Z" }, + { url = "https://files.pythonhosted.org/packages/85/70/92482ccffb96f5441aab93e26c4d66489eb599efdcf96fad90c14bbfb976/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:dbd936cde57abfee19ab3213cf9c26be06d60750e60a8e4dd85d1ab12c8b1f40", size = 556030, upload-time = "2025-11-30T20:24:10.956Z" }, + { url = "https://files.pythonhosted.org/packages/20/53/7c7e784abfa500a2b6b583b147ee4bb5a2b3747a9166bab52fec4b5b5e7d/rpds_py-0.30.0-cp314-cp314t-win32.whl", hash = "sha256:dc824125c72246d924f7f796b4f63c1e9dc810c7d9e2355864b3c3a73d59ade0", size = 211570, upload-time = "2025-11-30T20:24:12.735Z" }, + { url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, +] + [[package]] name = "rsa" version = "4.9.1" @@ -1183,6 +1315,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/9f/c3695c2d2d4ef70072c3a06992850498b01c6bc9be531950813716b426fa/sse_starlette-3.3.2.tar.gz", hash = "sha256:678fca55a1945c734d8472a6cad186a55ab02840b4f6786f5ee8770970579dcd", size = 32326, upload-time = "2026-02-28T11:24:34.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/28/8cb142d3fe80c4a2d8af54ca0b003f47ce0ba920974e7990fa6e016402d1/sse_starlette-3.3.2-py3-none-any.whl", hash = "sha256:5c3ea3dad425c601236726af2f27689b74494643f57017cafcb6f8c9acfbb862", size = 14270, upload-time = "2026-02-28T11:24:32.984Z" }, +] + [[package]] name = "starlette" version = "0.52.1" diff --git a/web/src/api/agent.ts b/web/src/api/agent.ts index f468a04..f1d683e 100644 --- a/web/src/api/agent.ts +++ b/web/src/api/agent.ts @@ -9,12 +9,18 @@ export interface AgentChatContext { export interface AgentChatRequest { messages: AgentChatMessage[]; auto_execute?: boolean; - approved_tool_call_ids?: string[]; - rejected_tool_call_ids?: string[]; + approved_mcp_call_ids?: string[]; + rejected_mcp_call_ids?: string[]; context?: AgentChatContext; } -export interface PendingToolCall { +export interface McpCall { + id: string; + name: string; + arguments: Record; +} + +export interface PendingMcpCall { id: string; name: string; arguments: Record; @@ -23,16 +29,16 @@ export interface PendingToolCall { export interface AgentChatResponse { messages: AgentChatMessage[]; - pending_tool_calls?: PendingToolCall[]; + pending_mcp_calls?: PendingMcpCall[]; } export type AgentSseEvent = | { event: 'assistant_start'; data: { id: string } } | { event: 'assistant_delta'; data: { id: string; delta: string } } | { event: 'assistant_end'; data: { id: string; message: AgentChatMessage } } - | { event: 'tool_start'; data: { tool_call_id: string; name: string } } - | { event: 'tool_end'; data: { tool_call_id: string; name: string; message: AgentChatMessage } } - | { event: 'pending'; data: { pending_tool_calls: PendingToolCall[] } } + | { event: 'mcp_call_start'; data: { mcp_call_id: string; name: string } } + | { event: 'mcp_call_end'; data: { mcp_call_id: string; name: string; message: AgentChatMessage } } + | { event: 'pending'; data: { pending_mcp_calls: PendingMcpCall[] } } | { event: 'done'; data: AgentChatResponse }; export const agentApi = { diff --git a/web/src/components/AiAgentWidget.tsx b/web/src/components/AiAgentWidget.tsx index 6f051ac..c4d7cfe 100644 --- a/web/src/components/AiAgentWidget.tsx +++ b/web/src/components/AiAgentWidget.tsx @@ -3,7 +3,7 @@ import { Avatar, Button, Divider, Flex, Input, List, Modal, Space, Switch, Tag, import { RobotOutlined, SendOutlined, DeleteOutlined, ToolOutlined, DownOutlined, UpOutlined, CodeOutlined, CopyOutlined, LoadingOutlined } from '@ant-design/icons'; import ReactMarkdown from 'react-markdown'; import type { TextAreaRef } from 'antd/es/input/TextArea'; -import { agentApi, type AgentChatMessage, type PendingToolCall } from '../api/agent'; +import { agentApi, type AgentChatMessage, type PendingMcpCall } from '../api/agent'; import { useI18n } from '../i18n'; import '../styles/ai-agent.css'; @@ -108,7 +108,7 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha const [input, setInput] = useState(''); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState([]); - const [pending, setPending] = useState([]); + const [pending, setPending] = useState([]); const [expandedTools, setExpandedTools] = useState>({}); const [expandedRaw, setExpandedRaw] = useState>({}); const [runningTools, setRunningTools] = useState>({}); @@ -153,16 +153,14 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha for (const msg of messages) { if (!msg || typeof msg !== 'object') continue; if (msg.role !== 'assistant') continue; - const toolCalls = (msg as any).tool_calls; + const toolCalls = (msg as any).mcp_calls; if (!Array.isArray(toolCalls)) continue; for (const call of toolCalls) { const id = typeof call?.id === 'string' ? call.id : ''; - const fn = call?.function; - const name = typeof fn?.name === 'string' ? fn.name : ''; - const rawArgs = typeof fn?.arguments === 'string' ? fn.arguments : ''; + const name = typeof call?.name === 'string' ? call.name : ''; + const args = isPlainObject(call?.arguments) ? call.arguments : {}; if (!id || !name) continue; - const parsedArgs = tryParseJson>(rawArgs) || {}; - map.set(id, { name, args: parsedArgs }); + map.set(id, { name, args }); } } return map; @@ -179,7 +177,7 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha assistantIndexRef.current = {}; setLoading(true); - const approvedIds = payload.approved_tool_call_ids || []; + const approvedIds = payload.approved_mcp_call_ids || []; if (Array.isArray(approvedIds) && approvedIds.length > 0) { const preRunning: Record = {}; approvedIds.forEach((id) => { @@ -196,8 +194,8 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha messages: payload.messages, auto_execute: autoExecute, context: effectivePath ? { current_path: effectivePath } : undefined, - approved_tool_call_ids: payload.approved_tool_call_ids, - rejected_tool_call_ids: payload.rejected_tool_call_ids, + approved_mcp_call_ids: payload.approved_mcp_call_ids, + rejected_mcp_call_ids: payload.rejected_mcp_call_ids, }, (evt) => { if (seq !== streamSeqRef.current) return; @@ -241,16 +239,16 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha delete assistantIndexRef.current[id]; return; } - case 'tool_start': { - const toolCallId = String((evt.data as any)?.tool_call_id || ''); + case 'mcp_call_start': { + const toolCallId = String((evt.data as any)?.mcp_call_id || ''); const name = String((evt.data as any)?.name || ''); if (!toolCallId) return; if (name) toolNameByIdRef.current[toolCallId] = name; setRunningTools((prev) => ({ ...prev, [toolCallId]: name || prev[toolCallId] || '' })); return; } - case 'tool_end': { - const toolCallId = String((evt.data as any)?.tool_call_id || ''); + case 'mcp_call_end': { + const toolCallId = String((evt.data as any)?.mcp_call_id || ''); const name = String((evt.data as any)?.name || ''); const msg = (evt.data as any)?.message; if (toolCallId && name) toolNameByIdRef.current[toolCallId] = name; @@ -267,14 +265,14 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha return; } case 'pending': { - const items = Array.isArray((evt.data as any)?.pending_tool_calls) ? (evt.data as any).pending_tool_calls : []; + const items = Array.isArray((evt.data as any)?.pending_mcp_calls) ? (evt.data as any).pending_mcp_calls : []; setPending(items); return; } case 'done': { const base = baseMessagesRef.current || []; const newMessages = Array.isArray((evt.data as any)?.messages) ? (evt.data as any).messages : []; - const nextPending = Array.isArray((evt.data as any)?.pending_tool_calls) ? (evt.data as any).pending_tool_calls : []; + const nextPending = Array.isArray((evt.data as any)?.pending_mcp_calls) ? (evt.data as any).pending_mcp_calls : []; setMessages([...base, ...newMessages]); setPending(nextPending); setRunningTools({}); @@ -326,23 +324,23 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha }, []); const approveOne = useCallback(async (id: string) => { - await runStream({ messages, approved_tool_call_ids: [id] }); + await runStream({ messages, approved_mcp_call_ids: [id] }); }, [messages, runStream]); const rejectOne = useCallback(async (id: string) => { - await runStream({ messages, rejected_tool_call_ids: [id] }); + await runStream({ messages, rejected_mcp_call_ids: [id] }); }, [messages, runStream]); const approveAll = useCallback(async () => { const ids = pending.map((p) => p.id).filter(Boolean); if (ids.length === 0) return; - await runStream({ messages, approved_tool_call_ids: ids }); + await runStream({ messages, approved_mcp_call_ids: ids }); }, [messages, pending, runStream]); const rejectAll = useCallback(async () => { const ids = pending.map((p) => p.id).filter(Boolean); if (ids.length === 0) return; - await runStream({ messages, rejected_tool_call_ids: ids }); + await runStream({ messages, rejected_mcp_call_ids: ids }); }, [messages, pending, runStream]); const messageItems = useMemo(() => { @@ -665,7 +663,7 @@ const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenCha const role = String((m as any).role); const isUser = role === 'user'; const isTool = role === 'tool'; - const toolCallId = typeof (m as any).tool_call_id === 'string' ? String((m as any).tool_call_id) : ''; + const toolCallId = typeof (m as any).mcp_call_id === 'string' ? String((m as any).mcp_call_id) : ''; const toolInfo = toolCallId ? toolCallsById.get(toolCallId) : null; const toolName = toolInfo?.name || (toolCallId ? toolNameByIdRef.current[toolCallId] : '') || ''; const msgKey = toolCallId ? `tool:${toolCallId}` : `${role}:${idx}`;