From a727e77341cc1ef045da00a4a0d375aa4f35a13e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=97=B6=E9=9B=A8?= Date: Fri, 9 Jan 2026 16:19:20 +0800 Subject: [PATCH] feat: Implement AI Agent with enhanced tool processing capabilities (#89) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Implement AI Agent with tool processing capabilities - Added tools for listing and running processors in the agent. - Created data models for agent chat requests and tool calls. - Developed API integration for agent chat and streaming responses. - Built the AI Agent widget with a user interface for interaction. - Styled the agent components for better user experience. * feat: Enhance the AI assistant tools: add file operations and search, and update the UI display * feat: Update the AI assistant component * feat: Update AiAgentWidget component styles, adjusting background and margins for better UI consistency --- api/routers.py | 2 + domain/agent/__init__.py | 4 + domain/agent/api.py | 39 ++ domain/agent/service.py | 448 +++++++++++++ domain/agent/tools.py | 413 ++++++++++++ domain/agent/types.py | 23 + domain/ai/inference.py | 196 +++++- web/src/api/agent.ts | 121 ++++ web/src/components/AiAgentWidget.tsx | 932 +++++++++++++++++++++++++++ web/src/i18n/locales/en.json | 37 +- web/src/i18n/locales/zh.json | 37 +- web/src/layout/TopHeader.tsx | 16 +- web/src/router/LayoutShell.tsx | 6 +- web/src/styles/ai-agent.css | 244 +++++++ 14 files changed, 2511 insertions(+), 7 deletions(-) create mode 100644 domain/agent/__init__.py create mode 100644 domain/agent/api.py create mode 100644 domain/agent/service.py create mode 100644 domain/agent/tools.py create mode 100644 domain/agent/types.py create mode 100644 web/src/api/agent.ts create mode 100644 web/src/components/AiAgentWidget.tsx create mode 100644 web/src/styles/ai-agent.css diff --git a/api/routers.py b/api/routers.py index 5fa3d51..8242400 100644 --- a/api/routers.py +++ b/api/routers.py @@ -11,6 +11,7 @@ from domain.processors import api as processors from domain.share import api as share from domain.tasks import api as tasks from domain.ai import api as ai +from domain.agent import api as agent from domain.virtual_fs import api as virtual_fs from domain.virtual_fs.mapping import s3_api, webdav_api from domain.virtual_fs.search import search_api @@ -30,6 +31,7 @@ def include_routers(app: FastAPI): app.include_router(backup.router) app.include_router(ai.router_vector_db) app.include_router(ai.router_ai) + app.include_router(agent.router) app.include_router(plugins.router) app.include_router(webdav_api.router) app.include_router(s3_api.router) diff --git a/domain/agent/__init__.py b/domain/agent/__init__.py new file mode 100644 index 0000000..703a4df --- /dev/null +++ b/domain/agent/__init__.py @@ -0,0 +1,4 @@ +from .api import router + +__all__ = ["router"] + diff --git a/domain/agent/api.py b/domain/agent/api.py new file mode 100644 index 0000000..b87dacd --- /dev/null +++ b/domain/agent/api.py @@ -0,0 +1,39 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import StreamingResponse + +from api.response import success +from domain.agent.service import AgentService +from domain.agent.types import AgentChatRequest +from domain.audit import AuditAction, audit +from domain.auth.service import get_current_active_user +from domain.auth.types import User + + +router = APIRouter(prefix="/api/agent", tags=["agent"]) + + +@router.post("/chat") +@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"]) +async def chat( + request: Request, + payload: AgentChatRequest, +
current_user: Annotated[User, Depends(get_current_active_user)], +): + data = await AgentService.chat(payload, current_user) + return success(data) + + +@router.post("/chat/stream") +@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"]) +async def chat_stream( + request: Request, + payload: AgentChatRequest, + current_user: Annotated[User, Depends(get_current_active_user)], +): + return StreamingResponse( + AgentService.chat_stream(payload, current_user), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache"}, + ) diff --git a/domain/agent/service.py b/domain/agent/service.py new file mode 100644 index 0000000..30b3ce5 --- /dev/null +++ b/domain/agent/service.py @@ -0,0 +1,448 @@ +import asyncio +import json +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import httpx +from fastapi import HTTPException + +from domain.agent.tools import get_tool, openai_tools, tool_result_to_content +from domain.agent.types import AgentChatRequest, PendingToolCall +from domain.ai.inference import MissingModelError, chat_completion, chat_completion_stream +from domain.ai.service import AIProviderService +from domain.auth.types import User + + +def _normalize_path(p: Optional[str]) -> Optional[str]: + if not p: + return None + s = str(p).strip() + if not s: + return None + s = s.replace("\\", "/") + if not s.startswith("/"): + s = "/" + s + s = s.rstrip("/") or "/" + return s + + +def _build_system_prompt(current_path: Optional[str]) -> str: + lines = [ + "你是 Foxel 的 AI 助手。", + "你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。", + "", + "可用工具:", + "- vfs_list_dir:浏览目录(列出 entries + pagination)。", + "- vfs_stat:查看文件/目录信息。", + "- vfs_read_text:读取文本文件内容(不支持二进制)。", + "- vfs_search:搜索文件(vector/filename)。", + "- vfs_write_text:写入文本文件内容(覆盖)。", + "- vfs_mkdir:创建目录。", + "- vfs_delete:删除文件或目录。", + "- vfs_move:移动路径。", + "- vfs_copy:复制路径。", + "- vfs_rename:重命名路径。", + "- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。", + "- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。", + "", + "规则:", + "1) 读操作(vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。", + "2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。", + "3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。", + "4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。", + "5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。", + "6) 回答保持简洁中文。", + ] + if current_path: + lines.append("") + lines.append(f"当前文件管理目录:{current_path}") + return "\n".join(lines) + + +def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]: + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list): + return message + + changed = False + for idx, call in enumerate(tool_calls): + if not isinstance(call, dict): + continue + call_id = call.get("id") + if isinstance(call_id, str) and call_id.strip(): + continue + call["id"] = f"call_{idx}" + changed = True + + if changed: + message["tool_calls"] = tool_calls + return message + + +def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall: + call_id = str(tool_call.get("id") or "") + fn = tool_call.get("function") or {} + name = str((fn.get("name") if isinstance(fn, dict) else None) or "") + raw_args = fn.get("arguments") if isinstance(fn, dict) else None + arguments: Dict[str, Any] = {} + if isinstance(raw_args, str) and raw_args.strip(): + try: + parsed = json.loads(raw_args) + if 
isinstance(parsed, dict): + arguments = parsed + except json.JSONDecodeError: + arguments = {} + return PendingToolCall( + id=call_id, + name=name, + arguments=arguments, + requires_confirmation=requires_confirmation, + ) + + +def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]: + for idx in range(len(messages) - 1, -1, -1): + msg = messages[idx] + if not isinstance(msg, dict): + continue + if msg.get("role") != "assistant": + continue + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list) and tool_calls: + return idx, msg + raise HTTPException(status_code=400, detail="没有可确认的待执行操作") + + +def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]: + ids: set[str] = set() + for msg in messages: + if not isinstance(msg, dict): + continue + if msg.get("role") != "tool": + continue + tool_call_id = msg.get("tool_call_id") + if isinstance(tool_call_id, str) and tool_call_id.strip(): + ids.add(tool_call_id) + return ids + + +async def _choose_chat_ability() -> str: + tools_model = await AIProviderService.get_default_model("tools") + return "tools" if tools_model else "chat" + + +def _sse(event: str, data: Any) -> bytes: + payload = json.dumps(data, ensure_ascii=False, separators=(",", ":")) + return f"event: {event}\ndata: {payload}\n\n".encode("utf-8") + + +class AgentService: + @classmethod + async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]: + history: List[Dict[str, Any]] = list(req.messages or []) + current_path = _normalize_path(req.context.current_path if req.context else None) + + system_prompt = _build_system_prompt(current_path) + internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history + + new_messages: List[Dict[str, Any]] = [] + pending: List[PendingToolCall] = [] + + approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()} + rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()} + + if approved_ids or rejected_ids: + _, last_call_msg = _find_last_assistant_tool_calls(internal_messages) + last_call_msg = _ensure_tool_call_ids(last_call_msg) + tool_calls = last_call_msg.get("tool_calls") or [] + call_map: Dict[str, Dict[str, Any]] = { + str(c.get("id")): c + for c in tool_calls + if isinstance(c, dict) and isinstance(c.get("id"), str) + } + + existing_ids = _existing_tool_result_ids(internal_messages) + for call_id in approved_ids | rejected_ids: + if call_id in existing_ids: + continue + tool_call = call_map.get(call_id) + if not tool_call: + continue + fn = tool_call.get("function") or {} + name = fn.get("name") if isinstance(fn, dict) else None + args_raw = fn.get("arguments") if isinstance(fn, dict) else None + args: Dict[str, Any] = {} + if isinstance(args_raw, str) and args_raw.strip(): + try: + parsed = json.loads(args_raw) + if isinstance(parsed, dict): + args = parsed + except json.JSONDecodeError: + args = {} + + spec = get_tool(str(name or "")) + if call_id in rejected_ids: + content = tool_result_to_content({"canceled": True, "reason": "user_rejected"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + continue + + if not spec: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + 
continue + + try: + result = await spec.handler(args) + content = tool_result_to_content(result) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + + tools_schema = openai_tools() + ability = await _choose_chat_ability() + max_loops = 4 + + for _ in range(max_loops): + try: + assistant = await chat_completion( + internal_messages, + ability=ability, + tools=tools_schema, + tool_choice="auto", + timeout=60.0, + ) + except MissingModelError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except httpx.HTTPStatusError as exc: + raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc + except httpx.RequestError as exc: + raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc + + assistant = _ensure_tool_call_ids(assistant) + internal_messages.append(assistant) + new_messages.append(assistant) + + tool_calls = assistant.get("tool_calls") + if not isinstance(tool_calls, list) or not tool_calls: + break + + pending = [] + for call in tool_calls: + if not isinstance(call, dict): + continue + call_id = str(call.get("id") or "") + fn = call.get("function") or {} + name = fn.get("name") if isinstance(fn, dict) else None + args_raw = fn.get("arguments") if isinstance(fn, dict) else None + args: Dict[str, Any] = {} + if isinstance(args_raw, str) and args_raw.strip(): + try: + parsed = json.loads(args_raw) + if isinstance(parsed, dict): + args = parsed + except json.JSONDecodeError: + args = {} + + spec = get_tool(str(name or "")) + if not spec: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + continue + + if spec.requires_confirmation and not req.auto_execute: + pending.append(_extract_pending(call, True)) + continue + + try: + result = await spec.handler(args) + content = tool_result_to_content(result) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + + if pending: + break + + payload: Dict[str, Any] = {"messages": new_messages} + if pending: + payload["pending_tool_calls"] = [p.model_dump() for p in pending] + return payload + + @classmethod + async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]): + history: List[Dict[str, Any]] = list(req.messages or []) + current_path = _normalize_path(req.context.current_path if req.context else None) + + system_prompt = _build_system_prompt(current_path) + internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history + + new_messages: List[Dict[str, Any]] = [] + pending: List[PendingToolCall] = [] + + approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()} + rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()} + + try: + if approved_ids or rejected_ids: + _, last_call_msg = _find_last_assistant_tool_calls(internal_messages) + last_call_msg = _ensure_tool_call_ids(last_call_msg) + tool_calls = last_call_msg.get("tool_calls") or [] + call_map: Dict[str, Dict[str, Any]] = { + str(c.get("id")): c + for c in tool_calls + if 
isinstance(c, dict) and isinstance(c.get("id"), str) + } + + existing_ids = _existing_tool_result_ids(internal_messages) + for call_id in approved_ids | rejected_ids: + if call_id in existing_ids: + continue + tool_call = call_map.get(call_id) + if not tool_call: + continue + fn = tool_call.get("function") or {} + name = fn.get("name") if isinstance(fn, dict) else None + args_raw = fn.get("arguments") if isinstance(fn, dict) else None + args: Dict[str, Any] = {} + if isinstance(args_raw, str) and args_raw.strip(): + try: + parsed = json.loads(args_raw) + if isinstance(parsed, dict): + args = parsed + except json.JSONDecodeError: + args = {} + + spec = get_tool(str(name or "")) + if call_id in rejected_ids: + content = tool_result_to_content({"canceled": True, "reason": "user_rejected"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) + continue + + if not spec: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) + continue + + yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name}) + try: + result = await spec.handler(args) + content = tool_result_to_content(result) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg}) + + tools_schema = openai_tools() + ability = await _choose_chat_ability() + max_loops = 4 + + for _ in range(max_loops): + assistant_event_id = uuid.uuid4().hex + yield _sse("assistant_start", {"id": assistant_event_id}) + + assistant_message: Dict[str, Any] | None = None + try: + async for event in chat_completion_stream( + internal_messages, + ability=ability, + tools=tools_schema, + tool_choice="auto", + timeout=60.0, + ): + if event.get("type") == "delta": + delta = event.get("delta") + if isinstance(delta, str) and delta: + yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta}) + elif event.get("type") == "message": + msg = event.get("message") + if isinstance(msg, dict): + assistant_message = msg + except MissingModelError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except httpx.HTTPStatusError as exc: + raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc + except httpx.RequestError as exc: + raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc + + if not assistant_message: + assistant_message = {"role": "assistant", "content": ""} + + assistant_message = _ensure_tool_call_ids(assistant_message) + internal_messages.append(assistant_message) + new_messages.append(assistant_message) + yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message}) + + tool_calls = assistant_message.get("tool_calls") + if not isinstance(tool_calls, list) or not tool_calls: + break + + pending = [] + for call in tool_calls: + if not isinstance(call, dict): + continue + call_id = str(call.get("id") or "") + fn = call.get("function") or {} + name = 
fn.get("name") if isinstance(fn, dict) else None + args_raw = fn.get("arguments") if isinstance(fn, dict) else None + args: Dict[str, Any] = {} + if isinstance(args_raw, str) and args_raw.strip(): + try: + parsed = json.loads(args_raw) + if isinstance(parsed, dict): + args = parsed + except json.JSONDecodeError: + args = {} + + spec = get_tool(str(name or "")) + if not spec: + content = tool_result_to_content({"error": f"unknown_tool: {name}"}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg}) + continue + + if spec.requires_confirmation and not req.auto_execute: + pending.append(_extract_pending(call, True)) + continue + + yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name}) + try: + result = await spec.handler(args) + content = tool_result_to_content(result) + except Exception as exc: # noqa: BLE001 + content = tool_result_to_content({"error": str(exc)}) + tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content} + internal_messages.append(tool_msg) + new_messages.append(tool_msg) + yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg}) + + if pending: + yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]}) + break + + payload: Dict[str, Any] = {"messages": new_messages} + if pending: + payload["pending_tool_calls"] = [p.model_dump() for p in pending] + yield _sse("done", payload) + + except asyncio.CancelledError: + return diff --git a/domain/agent/tools.py b/domain/agent/tools.py new file mode 100644 index 0000000..53f060c --- /dev/null +++ b/domain/agent/tools.py @@ -0,0 +1,413 @@ +import json +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, List, Optional + +from domain.processors.service import ProcessorService +from domain.processors.types import ProcessDirectoryRequest, ProcessRequest +from domain.virtual_fs.service import VirtualFSService +from domain.virtual_fs.search.search_service import VirtualFSSearchService + + +@dataclass(frozen=True) +class ToolSpec: + name: str + description: str + parameters: Dict[str, Any] + requires_confirmation: bool + handler: Callable[[Dict[str, Any]], Awaitable[Any]] + + +async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]: + return {"processors": ProcessorService.list_processors()} + + +async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]: + path = str(args.get("path") or "") + processor_type = str(args.get("processor_type") or "") + config = args.get("config") + if not isinstance(config, dict): + config = {} + + save_to = args.get("save_to") + save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None + + max_depth = args.get("max_depth") + max_depth_value: Optional[int] = None + if max_depth is not None: + try: + max_depth_value = int(max_depth) + except (TypeError, ValueError): + max_depth_value = None + + suffix = args.get("suffix") + suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None + + overwrite_value = args.get("overwrite") + overwrite = bool(overwrite_value) if overwrite_value is not None else None + + is_dir = await VirtualFSService.path_is_directory(path) + if is_dir and (max_depth_value is not None or suffix_value is not None): + req = ProcessDirectoryRequest( + path=path, + processor_type=processor_type, + config=config, + overwrite=True 
if overwrite is None else overwrite, + max_depth=max_depth_value, + suffix=suffix_value, + ) + result = await ProcessorService.process_directory(req) + return {"mode": "directory", **result} + + req = ProcessRequest( + path=path, + processor_type=processor_type, + config=config, + save_to=save_to, + overwrite=False if overwrite is None else overwrite, + ) + result = await ProcessorService.process_file(req) + return {"mode": "file", **result} + + +def _normalize_vfs_path(value: Any) -> str: + s = str(value or "").strip().replace("\\", "/") + if not s: + return "" + if not s.startswith("/"): + s = "/" + s + s = s.rstrip("/") or "/" + return s + + +def _require_vfs_path(value: Any, field: str) -> str: + path = _normalize_vfs_path(value) + if not path: + raise ValueError(f"missing_{field}") + return path + + +async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]: + path = _normalize_vfs_path(args.get("path") or "/") or "/" + page = int(args.get("page") or 1) + page_size = int(args.get("page_size") or 50) + sort_by = str(args.get("sort_by") or "name") + sort_order = str(args.get("sort_order") or "asc") + return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order) + + +async def _vfs_stat(args: Dict[str, Any]) -> Any: + path = _require_vfs_path(args.get("path"), "path") + return await VirtualFSService.stat(path) + + +async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]: + path = _require_vfs_path(args.get("path"), "path") + encoding = str(args.get("encoding") or "utf-8") + max_chars = int(args.get("max_chars") or 8000) + + data = await VirtualFSService.read_file(path) + if isinstance(data, (bytes, bytearray)): + try: + text = bytes(data).decode(encoding) + except UnicodeDecodeError: + return {"error": "binary_or_invalid_text", "path": path} + elif isinstance(data, str): + text = data + else: + text = str(data) + + original_len = len(text) + truncated = original_len > max_chars + if truncated: + text = text[:max_chars] + return { + "path": path, + "encoding": encoding, + "content": text, + "truncated": truncated, + "length": original_len, + } + + +async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]: + path = _require_vfs_path(args.get("path"), "path") + if path == "/": + raise ValueError("invalid_path") + encoding = str(args.get("encoding") or "utf-8") + content = str(args.get("content") or "") + data = content.encode(encoding) + await VirtualFSService.write_file(path, data) + return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)} + + +async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]: + path = _require_vfs_path(args.get("path"), "path") + return await VirtualFSService.mkdir(path) + + +async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]: + path = _require_vfs_path(args.get("path"), "path") + return await VirtualFSService.delete(path) + + +async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]: + src = _require_vfs_path(args.get("src"), "src") + dst = _require_vfs_path(args.get("dst"), "dst") + if src == "/" or dst == "/": + raise ValueError("invalid_path") + overwrite = bool(args.get("overwrite") or False) + return await VirtualFSService.move(src, dst, overwrite) + + +async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]: + src = _require_vfs_path(args.get("src"), "src") + dst = _require_vfs_path(args.get("dst"), "dst") + if src == "/" or dst == "/": + raise ValueError("invalid_path") + overwrite = bool(args.get("overwrite") or False) + return await VirtualFSService.copy(src, 
dst, overwrite) + + +async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]: + src = _require_vfs_path(args.get("src"), "src") + dst = _require_vfs_path(args.get("dst"), "dst") + if src == "/" or dst == "/": + raise ValueError("invalid_path") + overwrite = bool(args.get("overwrite") or False) + return await VirtualFSService.rename(src, dst, overwrite) + + +async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]: + q = str(args.get("q") or "").strip() + if not q: + raise ValueError("missing_q") + mode = str(args.get("mode") or "vector") + top_k = int(args.get("top_k") or 10) + page = int(args.get("page") or 1) + page_size = int(args.get("page_size") or 10) + return await VirtualFSSearchService.search(q, top_k, mode, page, page_size) + + +TOOLS: Dict[str, ToolSpec] = { + "processors_list": ToolSpec( + name="processors_list", + description="获取可用处理器列表(type/name/config_schema 等)。", + parameters={ + "type": "object", + "properties": {}, + "additionalProperties": False, + }, + requires_confirmation=False, + handler=_processors_list, + ), + "processors_run": ToolSpec( + name="processors_run", + description=( + "运行处理器处理文件或目录。" + " 对目录可选 max_depth/suffix;对文件可选 overwrite/save_to。" + " 返回任务 id(去任务队列查看进度)。" + ), + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar)"}, + "processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark)"}, + "config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"}, + "overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"}, + "save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"}, + "max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"}, + "suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false)"}, + }, + "required": ["path", "processor_type"], + }, + requires_confirmation=True, + handler=_processors_run, + ), + "vfs_list_dir": ToolSpec( + name="vfs_list_dir", + description="浏览目录(列出 entries + pagination)。", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"}, + "page": {"type": "integer", "description": "页码(从 1 开始)"}, + "page_size": {"type": "integer", "description": "每页条数"}, + "sort_by": {"type": "string", "description": "排序字段:name/size/mtime"}, + "sort_order": {"type": "string", "description": "排序顺序:asc/desc"}, + }, + "required": ["path"], + "additionalProperties": False, + }, + requires_confirmation=False, + handler=_vfs_list_dir, + ), + "vfs_stat": ToolSpec( + name="vfs_stat", + description="查看文件/目录信息(size/mtime/is_dir/has_thumbnail/vector_index 等)。", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt)"}, + }, + "required": ["path"], + "additionalProperties": False, + }, + requires_confirmation=False, + handler=_vfs_stat, + ), + "vfs_read_text": ToolSpec( + name="vfs_read_text", + description="读取文本文件内容(解码失败视为二进制,返回 error)。", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"}, + "encoding": {"type": "string", "description": "文本编码(默认 utf-8)"}, + "max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000)"}, + }, + "required": ["path"], + "additionalProperties": False, + }, + requires_confirmation=False, + handler=_vfs_read_text, + ), + "vfs_write_text": ToolSpec( + name="vfs_write_text", + description="写入文本文件内容(会覆盖目标文件)。", + parameters={ + 
"type": "object", + "properties": { + "path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"}, + "content": {"type": "string", "description": "要写入的文本内容"}, + "encoding": {"type": "string", "description": "文本编码(默认 utf-8)"}, + }, + "required": ["path", "content"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_write_text, + ), + "vfs_mkdir": ToolSpec( + name="vfs_mkdir", + description="创建目录。", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"}, + }, + "required": ["path"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_mkdir, + ), + "vfs_delete": ToolSpec( + name="vfs_delete", + description="删除文件或目录(由底层适配器决定是否递归)。", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt)"}, + }, + "required": ["path"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_delete, + ), + "vfs_move": ToolSpec( + name="vfs_move", + description="移动路径(可能进入任务队列)。", + parameters={ + "type": "object", + "properties": { + "src": {"type": "string", "description": "源路径(绝对路径)"}, + "dst": {"type": "string", "description": "目标路径(绝对路径)"}, + "overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"}, + }, + "required": ["src", "dst"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_move, + ), + "vfs_copy": ToolSpec( + name="vfs_copy", + description="复制路径(可能进入任务队列)。", + parameters={ + "type": "object", + "properties": { + "src": {"type": "string", "description": "源路径(绝对路径)"}, + "dst": {"type": "string", "description": "目标路径(绝对路径)"}, + "overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false)"}, + }, + "required": ["src", "dst"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_copy, + ), + "vfs_rename": ToolSpec( + name="vfs_rename", + description="重命名路径(本质是同目录 move)。", + parameters={ + "type": "object", + "properties": { + "src": {"type": "string", "description": "源路径(绝对路径)"}, + "dst": {"type": "string", "description": "目标路径(绝对路径)"}, + "overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"}, + }, + "required": ["src", "dst"], + "additionalProperties": False, + }, + requires_confirmation=True, + handler=_vfs_rename, + ), + "vfs_search": ToolSpec( + name="vfs_search", + description="搜索文件(mode=vector 或 filename)。", + parameters={ + "type": "object", + "properties": { + "q": {"type": "string", "description": "搜索关键词"}, + "mode": {"type": "string", "description": "搜索模式:vector/filename(默认 vector)"}, + "top_k": {"type": "integer", "description": "返回数量(vector 模式使用,默认 10)"}, + "page": {"type": "integer", "description": "页码(filename 模式使用,默认 1)"}, + "page_size": {"type": "integer", "description": "分页大小(filename 模式使用,默认 10)"}, + }, + "required": ["q"], + "additionalProperties": False, + }, + requires_confirmation=False, + handler=_vfs_search, + ), +} + + +def get_tool(name: str) -> Optional[ToolSpec]: + return TOOLS.get(name) + + +def openai_tools() -> List[Dict[str, Any]]: + out: List[Dict[str, Any]] = [] + for spec in TOOLS.values(): + out.append({ + "type": "function", + "function": { + "name": spec.name, + "description": spec.description, + "parameters": spec.parameters, + }, + }) + return out + + +def tool_result_to_content(result: Any) -> str: + if result is None: + return "" + if isinstance(result, str): + return result + try: + return json.dumps(result, 
ensure_ascii=False) + except TypeError: + return json.dumps({"result": str(result)}, ensure_ascii=False) diff --git a/domain/agent/types.py b/domain/agent/types.py new file mode 100644 index 0000000..7513ba2 --- /dev/null +++ b/domain/agent/types.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class AgentChatContext(BaseModel): + current_path: Optional[str] = None + + +class AgentChatRequest(BaseModel): + messages: List[Dict[str, Any]] = Field(default_factory=list) + auto_execute: bool = False + approved_tool_call_ids: List[str] = Field(default_factory=list) + rejected_tool_call_ids: List[str] = Field(default_factory=list) + context: Optional[AgentChatContext] = None + + +class PendingToolCall(BaseModel): + id: str + name: str + arguments: Dict[str, Any] = Field(default_factory=dict) + requires_confirmation: bool = True + diff --git a/domain/ai/inference.py b/domain/ai/inference.py index 8be093c..4399eb9 100644 --- a/domain/ai/inference.py +++ b/domain/ai/inference.py @@ -1,5 +1,7 @@ +import json + import httpx -from typing import List, Sequence, Tuple +from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple from models.database import AIModel, AIProvider from domain.ai.service import AIProviderService @@ -243,3 +245,195 @@ async def _rerank_with_gemini( except (TypeError, ValueError): scores.append(0.0) return scores + + +async def chat_completion( + messages: List[Dict[str, Any]], + *, + ability: str = "chat", + tools: List[Dict[str, Any]] | None = None, + tool_choice: Any | None = None, + temperature: float | None = None, + timeout: float = 60.0, +) -> Dict[str, Any]: + model, provider = await _require_model(ability) + if provider.api_format != "openai": + raise MissingModelError("当前仅支持 OpenAI 兼容接口的对话模型。") + return await _chat_with_openai( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ) + + +async def _chat_with_openai( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, +) -> Dict[str, Any]: + url = _openai_endpoint(provider, "/chat/completions") + payload: Dict[str, Any] = { + "model": model.name, + "messages": messages, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = tool_choice or "auto" + if temperature is not None: + payload["temperature"] = float(temperature) + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(url, headers=_openai_headers(provider), json=payload) + response.raise_for_status() + body = response.json() + + choices = body.get("choices") or [] + if not choices: + raise RuntimeError("对话接口返回为空") + message = choices[0].get("message") + if not isinstance(message, dict): + raise RuntimeError("对话接口返回格式异常") + return message + + +async def chat_completion_stream( + messages: List[Dict[str, Any]], + *, + ability: str = "chat", + tools: List[Dict[str, Any]] | None = None, + tool_choice: Any | None = None, + temperature: float | None = None, + timeout: float = 60.0, +) -> AsyncIterator[Dict[str, Any]]: + model, provider = await _require_model(ability) + if provider.api_format != "openai": + raise MissingModelError("当前仅支持 OpenAI 兼容接口的对话模型。") + async for event in _chat_stream_with_openai( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ): + yield event + + 
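# A minimal consumption sketch for chat_completion_stream (illustrative only, not
# part of the patch; assumes a default OpenAI-format chat model is configured,
# otherwise the call raises MissingModelError). Each {"type": "delta"} event
# carries an incremental text chunk; the final {"type": "message"} event carries
# the assembled assistant message, including any accumulated tool_calls.
# _print_stream is a hypothetical helper name used for the example.
async def _print_stream() -> None:
    parts: List[str] = []
    final: Dict[str, Any] | None = None
    async for event in chat_completion_stream(
        [{"role": "user", "content": "hello"}],
        ability="chat",
        timeout=60.0,
    ):
        if event.get("type") == "delta":
            parts.append(str(event.get("delta") or ""))  # incremental text chunk
        elif event.get("type") == "message":
            final = event.get("message")  # full message, including tool_calls
    # The streamed deltas concatenate to the final message content.
    assert final is None or final.get("content") == "".join(parts)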
+async def _chat_stream_with_openai( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, +) -> AsyncIterator[Dict[str, Any]]: + url = _openai_endpoint(provider, "/chat/completions") + payload: Dict[str, Any] = { + "model": model.name, + "messages": messages, + "stream": True, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = tool_choice or "auto" + if temperature is not None: + payload["temperature"] = float(temperature) + + content_parts: List[str] = [] + tool_call_map: Dict[int, Dict[str, Any]] = {} + role = "assistant" + finish_reason: str | None = None + + async with httpx.AsyncClient(timeout=timeout) as client: + async with client.stream("POST", url, headers=_openai_headers(provider), json=payload) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: + continue + if not line.startswith("data:"): + continue + data = line[5:].strip() + if not data: + continue + if data == "[DONE]": + break + try: + chunk = json.loads(data) + except json.JSONDecodeError: + continue + + choices = chunk.get("choices") or [] + if not choices: + continue + choice = choices[0] if isinstance(choices[0], dict) else {} + delta = choice.get("delta") if isinstance(choice, dict) else None + delta = delta if isinstance(delta, dict) else {} + + if isinstance(delta.get("role"), str): + role = delta["role"] + + delta_content = delta.get("content") + if isinstance(delta_content, str) and delta_content: + content_parts.append(delta_content) + yield {"type": "delta", "delta": delta_content} + + delta_tool_calls = delta.get("tool_calls") + if isinstance(delta_tool_calls, list): + for item in delta_tool_calls: + if not isinstance(item, dict): + continue + idx = item.get("index") + if not isinstance(idx, int): + continue + entry = tool_call_map.setdefault( + idx, + {"id": None, "type": None, "function": {"name": None, "arguments": ""}}, + ) + if isinstance(item.get("id"), str) and item["id"].strip(): + entry["id"] = item["id"] + if isinstance(item.get("type"), str) and item["type"].strip(): + entry["type"] = item["type"] + fn = item.get("function") + if isinstance(fn, dict): + if isinstance(fn.get("name"), str) and fn["name"].strip(): + entry["function"]["name"] = fn["name"] + args_part = fn.get("arguments") + if isinstance(args_part, str) and args_part: + entry["function"]["arguments"] += args_part + + fr = choice.get("finish_reason") if isinstance(choice, dict) else None + if isinstance(fr, str) and fr: + finish_reason = fr + + content = "".join(content_parts) + message: Dict[str, Any] = {"role": role, "content": content} + if tool_call_map: + tool_calls: List[Dict[str, Any]] = [] + for idx in sorted(tool_call_map.keys()): + item = tool_call_map[idx] + fn = item.get("function") if isinstance(item.get("function"), dict) else {} + call_id = item.get("id") if isinstance(item.get("id"), str) and item.get("id") else f"call_{idx}" + call_type = item.get("type") if isinstance(item.get("type"), str) and item.get("type") else "function" + tool_calls.append({ + "id": call_id, + "type": call_type, + "function": { + "name": fn.get("name") or "", + "arguments": fn.get("arguments") or "", + }, + }) + message["tool_calls"] = tool_calls + + yield {"type": "message", "message": message, "finish_reason": finish_reason} diff --git a/web/src/api/agent.ts b/web/src/api/agent.ts new file mode 100644 index 0000000..f468a04 --- /dev/null 
+++ b/web/src/api/agent.ts @@ -0,0 +1,121 @@ +import request, { API_BASE_URL } from './client'; + +export type AgentChatMessage = Record<string, any>; + +export interface AgentChatContext { + current_path?: string | null; +} + +export interface AgentChatRequest { + messages: AgentChatMessage[]; + auto_execute?: boolean; + approved_tool_call_ids?: string[]; + rejected_tool_call_ids?: string[]; + context?: AgentChatContext; +} + +export interface PendingToolCall { + id: string; + name: string; + arguments: Record<string, any>; + requires_confirmation: boolean; +} + +export interface AgentChatResponse { + messages: AgentChatMessage[]; + pending_tool_calls?: PendingToolCall[]; +} + +export type AgentSseEvent = + | { event: 'assistant_start'; data: { id: string } } + | { event: 'assistant_delta'; data: { id: string; delta: string } } + | { event: 'assistant_end'; data: { id: string; message: AgentChatMessage } } + | { event: 'tool_start'; data: { tool_call_id: string; name: string } } + | { event: 'tool_end'; data: { tool_call_id: string; name: string; message: AgentChatMessage } } + | { event: 'pending'; data: { pending_tool_calls: PendingToolCall[] } } + | { event: 'done'; data: AgentChatResponse }; + +export const agentApi = { + chat: (payload: AgentChatRequest) => + request('/agent/chat', { + method: 'POST', + json: payload, + }), + chatStream: async ( + payload: AgentChatRequest, + onEvent: (evt: AgentSseEvent) => void, + options?: { signal?: AbortSignal } + ) => { + const headers: Record<string, string> = { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }; + const token = localStorage.getItem('token'); + if (token) headers['Authorization'] = `Bearer ${token}`; + + const resp = await fetch(`${API_BASE_URL}/agent/chat/stream`, { + method: 'POST', + headers, + body: JSON.stringify(payload), + signal: options?.signal, + }); + + if (!resp.ok) { + let errMsg = resp.statusText; + try { + const data = await resp.json(); + if (Array.isArray((data as any)?.detail)) { + errMsg = (data as any).detail.map((e: any) => e.msg || JSON.stringify(e)).join('; '); + } else { + errMsg = (typeof (data as any)?.detail === 'string') ?
(data as any).detail : JSON.stringify(data); + } + } catch { + try { + errMsg = await resp.text(); + } catch { void 0; } + } + throw new Error(errMsg || `Request failed: ${resp.status}`); + } + + const reader = resp.body?.getReader(); + if (!reader) throw new Error('Stream not supported'); + + const decoder = new TextDecoder(); + let buffer = ''; + + const flush = (raw: string) => { + const lines = raw.split(/\r?\n/); + let eventName = 'message'; + const dataLines: string[] = []; + for (const line of lines) { + if (line.startsWith('event:')) { + eventName = line.slice(6).trim(); + } else if (line.startsWith('data:')) { + dataLines.push(line.slice(5).trimStart()); + } + } + const dataStr = dataLines.join('\n').trim(); + if (!eventName || !dataStr) return; + try { + const data = JSON.parse(dataStr); + onEvent({ event: eventName as any, data } as any); + } catch { + // ignore parse error + } + }; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + while (true) { + const idx = buffer.indexOf('\n\n'); + if (idx === -1) break; + const chunk = buffer.slice(0, idx); + buffer = buffer.slice(idx + 2); + if (chunk.trim()) flush(chunk); + } + } + if (buffer.trim()) flush(buffer); + }, +}; diff --git a/web/src/components/AiAgentWidget.tsx b/web/src/components/AiAgentWidget.tsx new file mode 100644 index 0000000..fd05812 --- /dev/null +++ b/web/src/components/AiAgentWidget.tsx @@ -0,0 +1,932 @@ +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { Avatar, Button, Divider, Drawer, Flex, Input, List, Space, Switch, Tag, Typography, message, theme } from 'antd'; +import { RobotOutlined, SendOutlined, FolderOpenOutlined, DeleteOutlined, ToolOutlined, DownOutlined, UpOutlined, CodeOutlined, CopyOutlined, LoadingOutlined } from '@ant-design/icons'; +import ReactMarkdown from 'react-markdown'; +import PathSelectorModal from './PathSelectorModal'; +import { agentApi, type AgentChatMessage, type PendingToolCall } from '../api/agent'; +import { useI18n } from '../i18n'; +import '../styles/ai-agent.css'; + +const { Text, Paragraph } = Typography; + +function normalizePath(p?: string | null): string | null { + if (!p) return null; + const s = ('/' + p).replace(/\/+/, '/').replace(/\\/g, '/').replace(/\/+$/, '') || '/'; + return s; +} + +function extractTextContent(content: any): string { + if (content == null) return ''; + if (typeof content === 'string') return content; + if (Array.isArray(content)) { + const parts: string[] = []; + for (const item of content) { + if (typeof item === 'string') { + if (item.trim()) parts.push(item); + continue; + } + const text = typeof item?.text === 'string' ? 
item.text : ''; + if (text.trim()) parts.push(text); + } + return parts.join('\n'); + } + try { + return JSON.stringify(content, null, 2); + } catch { + return String(content); + } +} + +function tryParseJson<T>(raw: string): T | null { + if (typeof raw !== 'string') return null; + const s = raw.trim(); + if (!s) return null; + try { + return JSON.parse(s) as T; + } catch { + return null; + } +} + +function shortId(id: string, keep: number = 6): string { + const s = String(id || ''); + if (s.length <= keep * 2 + 3) return s; + return `${s.slice(0, keep)}…${s.slice(-keep)}`; +} + +interface AiAgentWidgetProps { + currentPath?: string | null; + open: boolean; + onOpenChange(open: boolean): void; +} + +const AiAgentWidget = memo(function AiAgentWidget({ currentPath, open, onOpenChange }: AiAgentWidgetProps) { + const { t } = useI18n(); + const { token } = theme.useToken(); + const [autoExecute, setAutoExecute] = useState(false); + const [input, setInput] = useState(''); + const [loading, setLoading] = useState(false); + const [messages, setMessages] = useState<AgentChatMessage[]>([]); + const [pending, setPending] = useState<PendingToolCall[]>([]); + const [pathModalOpen, setPathModalOpen] = useState(false); + const [expandedTools, setExpandedTools] = useState<Record<string, boolean>>({}); + const [expandedRaw, setExpandedRaw] = useState<Record<string, boolean>>({}); + const [runningTools, setRunningTools] = useState<Record<string, string>>({}); + const scrollRef = useRef<HTMLDivElement | null>(null); + const streamControllerRef = useRef<AbortController | null>(null); + const streamSeqRef = useRef(0); + const baseMessagesRef = useRef<AgentChatMessage[]>([]); + const assistantIndexRef = useRef<Record<string, number>>({}); + const toolNameByIdRef = useRef<Record<string, string>>({}); + + const effectivePath = useMemo(() => normalizePath(currentPath), [currentPath]); + + const scrollToBottom = useCallback(() => { + const el = scrollRef.current; + if (!el) return; + el.scrollTop = el.scrollHeight; + }, []); + + useEffect(() => { + if (!open) return; + const t = window.setTimeout(scrollToBottom, 0); + return () => window.clearTimeout(t); + }, [messages, open, pending, scrollToBottom]); + + useEffect(() => { + return () => { + streamControllerRef.current?.abort(); + }; + }, []); + + const toolCallsById = useMemo(() => { + const map = new Map<string, { name: string; args: Record<string, any> }>(); + for (const msg of messages) { + if (!msg || typeof msg !== 'object') continue; + if (msg.role !== 'assistant') continue; + const toolCalls = (msg as any).tool_calls; + if (!Array.isArray(toolCalls)) continue; + for (const call of toolCalls) { + const id = typeof call?.id === 'string' ? call.id : ''; + const fn = call?.function; + const name = typeof fn?.name === 'string' ? fn.name : ''; + const rawArgs = typeof fn?.arguments === 'string' ?
fn.arguments : ''; + if (!id || !name) continue; + const parsedArgs = tryParseJson<Record<string, any>>(rawArgs) || {}; + map.set(id, { name, args: parsedArgs }); + } + } + return map; + }, [messages]); + + const runStream = useCallback(async (payload: Partial<Parameters<typeof agentApi.chatStream>[0]> & { messages: AgentChatMessage[] }) => { + streamControllerRef.current?.abort(); + const controller = new AbortController(); + streamControllerRef.current = controller; + streamSeqRef.current += 1; + const seq = streamSeqRef.current; + + baseMessagesRef.current = payload.messages; + assistantIndexRef.current = {}; + + setLoading(true); + const approvedIds = payload.approved_tool_call_ids || []; + if (Array.isArray(approvedIds) && approvedIds.length > 0) { + const preRunning: Record<string, string> = {}; + approvedIds.forEach((id) => { + if (typeof id === 'string' && id.trim()) preRunning[id] = ''; + }); + setRunningTools(preRunning); + } else { + setRunningTools({}); + } + + try { + await agentApi.chatStream( + { + messages: payload.messages, + auto_execute: autoExecute, + context: effectivePath ? { current_path: effectivePath } : undefined, + approved_tool_call_ids: payload.approved_tool_call_ids, + rejected_tool_call_ids: payload.rejected_tool_call_ids, + }, + (evt) => { + if (seq !== streamSeqRef.current) return; + switch (evt.event) { + case 'assistant_start': { + const id = String((evt.data as any)?.id || ''); + if (!id) return; + setMessages((prev) => { + const idx = prev.length; + assistantIndexRef.current[id] = idx; + return [...prev, { role: 'assistant', content: '' }]; + }); + return; + } + case 'assistant_delta': { + const id = String((evt.data as any)?.id || ''); + const delta = String((evt.data as any)?.delta || ''); + if (!id || !delta) return; + setMessages((prev) => { + const idx = assistantIndexRef.current[id]; + if (idx === undefined || idx < 0 || idx >= prev.length) return prev; + const cur = prev[idx] as any; + const curContent = typeof cur?.content === 'string' ? cur.content : extractTextContent(cur?.content); + const next = prev.slice(); + next[idx] = { ...cur, content: (curContent || '') + delta }; + return next; + }); + return; + } + case 'assistant_end': { + const id = String((evt.data as any)?.id || ''); + const msg = (evt.data as any)?.message; + if (!id || !msg || typeof msg !== 'object') return; + setMessages((prev) => { + const idx = assistantIndexRef.current[id]; + if (idx === undefined || idx < 0 || idx >= prev.length) return prev; + const next = prev.slice(); + next[idx] = msg; + return next; + }); + delete assistantIndexRef.current[id]; + return; + } + case 'tool_start': { + const toolCallId = String((evt.data as any)?.tool_call_id || ''); + const name = String((evt.data as any)?.name || ''); + if (!toolCallId) return; + if (name) toolNameByIdRef.current[toolCallId] = name; + setRunningTools((prev) => ({ ...prev, [toolCallId]: name || prev[toolCallId] || '' })); + return; + } + case 'tool_end': { + const toolCallId = String((evt.data as any)?.tool_call_id || ''); + const name = String((evt.data as any)?.name || ''); + const msg = (evt.data as any)?.message; + if (toolCallId && name) toolNameByIdRef.current[toolCallId] = name; + if (toolCallId) { + setRunningTools((prev) => { + const next = { ...prev }; + delete next[toolCallId]; + return next; + }); + } + if (msg && typeof msg === 'object') { + setMessages((prev) => [...prev, msg]); + } + return; + } + case 'pending': { + const items = Array.isArray((evt.data as any)?.pending_tool_calls) ?
(evt.data as any).pending_tool_calls : []; + setPending(items); + return; + } + case 'done': { + const base = baseMessagesRef.current || []; + const newMessages = Array.isArray((evt.data as any)?.messages) ? (evt.data as any).messages : []; + const nextPending = Array.isArray((evt.data as any)?.pending_tool_calls) ? (evt.data as any).pending_tool_calls : []; + setMessages([...base, ...newMessages]); + setPending(nextPending); + setRunningTools({}); + assistantIndexRef.current = {}; + return; + } + default: + return; + } + }, + { signal: controller.signal } + ); + } catch (err: any) { + if (controller.signal.aborted) return; + message.error(err?.message || t('Operation failed')); + } finally { + if (seq === streamSeqRef.current) { + setLoading(false); + if (controller.signal.aborted) { + setRunningTools({}); + assistantIndexRef.current = {}; + } + } + } + }, [autoExecute, effectivePath, t]); + + const handleSend = useCallback(async () => { + const text = input.trim(); + if (!text) return; + if (pending.length > 0) { + message.warning(t('Please confirm pending actions first')); + return; + } + const nextUserMsg: AgentChatMessage = { role: 'user', content: text }; + setInput(''); + const base = [...messages, nextUserMsg]; + setMessages(base); + setPending([]); + await runStream({ messages: base }); + }, [input, messages, pending.length, runStream, t]); + + const clearChat = useCallback(() => { + streamControllerRef.current?.abort(); + setMessages([]); + setPending([]); + setExpandedTools({}); + setExpandedRaw({}); + setRunningTools({}); + }, []); + + const approveOne = useCallback(async (id: string) => { + await runStream({ messages, approved_tool_call_ids: [id] }); + }, [messages, runStream]); + + const rejectOne = useCallback(async (id: string) => { + await runStream({ messages, rejected_tool_call_ids: [id] }); + }, [messages, runStream]); + + const approveAll = useCallback(async () => { + const ids = pending.map((p) => p.id).filter(Boolean); + if (ids.length === 0) return; + await runStream({ messages, approved_tool_call_ids: ids }); + }, [messages, pending, runStream]); + + const rejectAll = useCallback(async () => { + const ids = pending.map((p) => p.id).filter(Boolean); + if (ids.length === 0) return; + await runStream({ messages, rejected_tool_call_ids: ids }); + }, [messages, pending, runStream]); + + const handlePathSelected = useCallback((path: string) => { + const p = normalizePath(path) || '/'; + setInput((prev) => (prev.trim() ? `${prev.trim()} ${p}` : p)); + setPathModalOpen(false); + }, []); + + const messageItems = useMemo(() => { + return messages.filter((m) => { + if (!m || typeof m !== 'object') return false; + const role = typeof (m as any).role === 'string' ? 
String((m as any).role) : ''; + if (!role || role === 'system') return false; + if (role === 'assistant') { + const text = extractTextContent((m as any).content); + return !!text.trim(); + } + return true; + }); + }, [messages]); + + const runningToolEntries = useMemo(() => Object.entries(runningTools).filter(([id]) => !!id), [runningTools]); + const runningToolCount = runningToolEntries.length; + + const copyToClipboard = useCallback(async (raw: string) => { + try { + await navigator.clipboard.writeText(raw); + message.success(t('Copied')); + } catch (err: any) { + message.error(err?.message || t('Operation failed')); + } + }, [t]); + + const renderToolResultSummary = useCallback((toolName: string, rawContent: string, toolArgs?: Record<string, any> | null) => { + const data = tryParseJson<Record<string, any>>(rawContent); + if (!data) return ''; + + if (data.canceled) return t('Canceled'); + if (data.error) return `${t('Error')}: ${String(data.error)}`; + + if (toolName === 'processors_list') { + const processors = Array.isArray(data.processors) ? data.processors : []; + return `${t('Processors')}: ${processors.length}`; + } + if (toolName === 'processors_run') { + const ctx = (() => { + const processorType = typeof toolArgs?.processor_type === 'string' ? toolArgs.processor_type.trim() : ''; + const path = typeof toolArgs?.path === 'string' ? toolArgs.path.trim() : ''; + const parts = [processorType, path].filter(Boolean); + return parts.length ? parts.join(' · ') : ''; + })(); + if (typeof data.task_id === 'string') { + return ctx ? `${t('Task submitted')}: ${ctx} · ${shortId(data.task_id)}` : `${t('Task submitted')}: ${shortId(data.task_id)}`; + } + const taskIds = Array.isArray(data.task_ids) ? data.task_ids : []; + const scheduled = typeof data.scheduled === 'number' ? data.scheduled : taskIds.length; + if (scheduled) return ctx ? `${t('Tasks submitted')}: ${ctx} · ${scheduled}` : `${t('Tasks submitted')}: ${scheduled}`; + return t('Task submitted'); + } + if (toolName === 'vfs_list_dir') { + const path = typeof data.path === 'string' ? data.path : ''; + const entries = Array.isArray(data.entries) ? data.entries : []; + const names = entries + .map((it: any) => String(it?.name || '').trim()) + .filter(Boolean) + .slice(0, 3); + const head = `${t('Directory')}: ${path || '/'}`; + const tail = `${entries.length} ${t('items')}`; + const sample = names.length ? ` · ${names.join(', ')}` : ''; + return `${head} · ${tail}${sample}`; + } + if (toolName === 'vfs_search') { + const query = typeof data.query === 'string' ? data.query : ''; + const items = Array.isArray(data.items) ? data.items : []; + return `${t('Search')}: ${query || '-'} · ${items.length} ${t('results')}`; + } + if (toolName === 'vfs_stat') { + const isDir = Boolean(data.is_dir); + const path = typeof data.path === 'string' ? data.path : ''; + return `${t('Info')}: ${path || '-'} · ${isDir ? t('Folder') : t('File')}`; + } + if (toolName === 'vfs_read_text') { + const path = typeof data.path === 'string' ? data.path : ''; + const length = typeof data.length === 'number' ? data.length : undefined; + const truncated = Boolean(data.truncated); + const tail = length != null ? ` · ${length} ${t('chars')}${truncated ? `(${t('Truncated')})` : ''}` : ''; + return `${t('Read')}: ${path || '-'}${tail}`; + } + if (toolName === 'vfs_write_text') { + const path = typeof data.path === 'string' ? data.path : ''; + const bytes = typeof data.bytes === 'number' ? data.bytes : undefined; + return `${t('Write')}: ${path || '-'}${bytes != null ?
` · ${bytes} bytes` : ''}`; + } + if (toolName === 'vfs_mkdir') { + const path = typeof data.path === 'string' ? data.path : ''; + return `${t('Created')}: ${path || '-'}`; + } + if (toolName === 'vfs_delete') { + const path = typeof data.path === 'string' ? data.path : ''; + return `${t('Deleted')}: ${path || '-'}`; + } + if (toolName === 'vfs_move') { + const src = typeof data.src === 'string' ? data.src : ''; + const dst = typeof data.dst === 'string' ? data.dst : ''; + return `${t('Moved')}: ${src || '-'} → ${dst || '-'}`; + } + if (toolName === 'vfs_copy') { + const src = typeof data.src === 'string' ? data.src : ''; + const dst = typeof data.dst === 'string' ? data.dst : ''; + return `${t('Copied')}: ${src || '-'} → ${dst || '-'}`; + } + if (toolName === 'vfs_rename') { + const src = typeof data.src === 'string' ? data.src : ''; + const dst = typeof data.dst === 'string' ? data.dst : ''; + return `${t('Renamed')}: ${src || '-'} → ${dst || '-'}`; + } + return ''; + }, [t]); + + const renderToolDetails = useCallback((toolKey: string, toolName: string, rawContent: string) => { + const data = tryParseJson>(rawContent); + const showRaw = !!expandedRaw[toolKey]; + const toggleRaw = () => setExpandedRaw((prev) => ({ ...prev, [toolKey]: !prev[toolKey] })); + + const rawJson = (() => { + if (!rawContent?.trim()) return ''; + const parsed = tryParseJson(rawContent); + if (!parsed) return rawContent; + try { + return JSON.stringify(parsed, null, 2); + } catch { + return rawContent; + } + })(); + + const header = ( + + + {showRaw && ( + + )} + + ); + + if (toolName === 'processors_list') { + const processors = Array.isArray(data?.processors) ? data!.processors : []; + return ( +
+ {header} + + ( + + + {String(item?.type || '')} + {String(item?.name || '')} + + + )} + style={{ background: 'transparent' }} + /> + {showRaw && ( + <> + +
{rawJson}
+ + )} +
+ ); + } + + if (toolName === 'vfs_list_dir') { + const path = typeof data?.path === 'string' ? data!.path : '/'; + const entries = Array.isArray(data?.entries) ? data!.entries : []; + const pagination = data?.pagination && typeof data.pagination === 'object' ? data.pagination : null; + return ( +
+ {header} + + + {t('Directory')}: {path} + {pagination?.total != null ? ( + + {t('Total')}: {String(pagination.total)} + + ) : null} + + + { + const name = String(item?.name || ''); + const type = String(item?.type || (item?.is_dir ? 'dir' : 'file')); + return ( + + + + {type} + {name} + + {!item?.is_dir && typeof item?.size === 'number' ? ( + {item.size} bytes + ) : null} + + + ); + }} + style={{ background: 'transparent' }} + /> + {showRaw && ( + <> + +
{rawJson}
+ + )} +
+ ); + } + + if (toolName === 'vfs_search') { + const query = typeof data?.query === 'string' ? data!.query : ''; + const mode = typeof data?.mode === 'string' ? data!.mode : ''; + const items = Array.isArray(data?.items) ? data!.items : []; + const pagination = data?.pagination && typeof data.pagination === 'object' ? data.pagination : null; + return ( +
+ {header} + + + {t('Search')}: {query || '-'} + {t('Mode')}: {mode || '-'} + {pagination?.has_more != null ? ( + + {t('Page')}: {String(pagination.page)} · {t('Has more')}: {String(Boolean(pagination.has_more))} + + ) : null} + + + { + const type = String(item?.source_type || item?.mime || ''); + const path = String(item?.path || ''); + const score = item?.score != null ? Number(item.score) : null; + return ( + + + + {type ? {type} : null} + {path} + + {score != null && !Number.isNaN(score) ? ( + {score.toFixed(3)} + ) : null} + + + ); + }} + style={{ background: 'transparent' }} + /> + {showRaw && ( + <> + +
{rawJson}
+ + )} +
+ ); + } + + if (toolName === 'vfs_read_text') { + const path = typeof data?.path === 'string' ? data!.path : ''; + const content = typeof data?.content === 'string' ? data!.content : ''; + return ( +
+ {header} + + {t('File')}: {path || '-'} +
{content || ''}
+ {showRaw && ( + <> + +
{rawJson}
+ + )} +
+ ); + } + + return ( +
+ {header} + + {showRaw ? ( +
{rawJson}
+ ) : ( + + {extractTextContent(data ?? rawContent) || {t('No content')}} + + )} +
+    );
+  }, [copyToClipboard, expandedRaw, t]);
+
+  // Short argument preview shown on pending (unconfirmed) tool calls.
+  const renderToolArgsSummary = useCallback((toolName: string, args?: Record<string, any> | null) => {
+    const a = args || {};
+    if (toolName === 'processors_run') {
+      const path = typeof a.path === 'string' ? a.path : '';
+      return path ? `${t('Path')}: ${path}` : '';
+    }
+    if (toolName === 'vfs_read_text' || toolName === 'vfs_list_dir' || toolName === 'vfs_stat' || toolName === 'vfs_delete' || toolName === 'vfs_mkdir') {
+      const path = typeof a.path === 'string' ? a.path : '';
+      return path ? `${t('Path')}: ${path}` : '';
+    }
+    if (toolName === 'vfs_search') {
+      const query = typeof a.query === 'string' ? a.query : '';
+      return query ? `${t('Search')}: ${query}` : '';
+    }
+    if (toolName === 'vfs_write_text') {
+      const path = typeof a.path === 'string' ? a.path : '';
+      return path ? `${t('Path')}: ${path}` : '';
+    }
+    if (toolName === 'vfs_move' || toolName === 'vfs_copy' || toolName === 'vfs_rename') {
+      const src = typeof a.src === 'string' ? a.src : '';
+      const dst = typeof a.dst === 'string' ? a.dst : '';
+      if (src && dst) return `${src} → ${dst}`;
+      if (src) return src;
+      if (dst) return dst;
+      return '';
+    }
+    return '';
+  }, [t]);
+
+  return (
+    <>
+        { streamControllerRef.current?.abort(); onOpenChange(false); }} width={520} mask={false} destroyOnHidden styles={{ body: { padding: 8, background: token.colorBgContainer, }, }} extra={
+
+            {t('Auto execute')}
+
+
+
+        }
+      >
+
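+        {/* Body: conversation area first (empty state, transcript, running-tool
+            status, pending confirmations), then the composer at the bottom. */}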
+ {messageItems.length === 0 ? ( +
+ } style={{ background: token.colorPrimary }} /> +
+ {t('Start a conversation')} +
+
+ ) : ( +
+              {messageItems.map((m, idx) => {
+                const role = String((m as any).role);
+                const isUser = role === 'user';
+                const isTool = role === 'tool';
+                const toolCallId = typeof (m as any).tool_call_id === 'string' ? String((m as any).tool_call_id) : '';
+                const toolInfo = toolCallId ? toolCallsById.get(toolCallId) : null;
+                const toolName = toolInfo?.name || (toolCallId ? toolNameByIdRef.current[toolCallId] : '') || '';
+                const msgKey = toolCallId ? `tool:${toolCallId}` : `${role}:${idx}`;
+
+                if (isTool) {
+                  const rawContent = extractTextContent((m as any).content);
+                  const expanded = !!expandedTools[msgKey];
+                  const summary = toolName ? renderToolResultSummary(toolName, rawContent, toolInfo?.args || null) : '';
+                  return (
+
+
+ + }> + {t('MCP Tool')} + + }> + {toolName || t('Tool')} + + + +
+ {summary ? ( +
+ {summary} +
+ ) : null} + {expanded && ( +
+ {toolInfo?.args && Object.keys(toolInfo.args).length > 0 && ( +
+ {t('Arguments')} +
+                                    {JSON.stringify(toolInfo.args, null, 2)}
+                                  
+
+ )} + {renderToolDetails(msgKey, toolName || t('Tool'), rawContent)} +
+ )} +
+
+ ); + } + + const text = extractTextContent((m as any).content); + if (isUser) { + return ( +
+
+ {text.trim() ?
{text}
: {t('No content')}} +
+
+ ); + } + + return ( +
+
+ {text.trim() ? ( +
+ {text} +
+ ) : ( + {t('No content')} + )} +
+
+ ); + })} + {runningToolCount > 0 && ( +
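+                  {/* Transient status while tool calls are in flight: shows up
+                      to two tool names, then a +N overflow badge. */}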
+ + {t('Calling tools')} + + {runningToolEntries.slice(0, 2).map(([id, name]) => ( + + {(name || t('Tool'))} #{shortId(id, 4)} + + ))} + {runningToolCount > 2 && ( + +{runningToolCount - 2} + )} + +
+ )} + {pending.length > 0 && ( +
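+                  {/* Mutating tool calls are parked here until the user confirms
+                      or cancels them, unless auto execute is enabled. */}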
+
+ + + {t('Pending actions')} + + {pending.length} + + + + + +
+ +
+                    {pending.map((p) => {
+                      const args = p.arguments || {};
+                      const key = `pending:${p.id}`;
+                      const expanded = !!expandedTools[key];
+                      const running = Object.prototype.hasOwnProperty.call(runningTools, p.id);
+                      const summary = renderToolArgsSummary(p.name, args);
+                      return (
+
+ + }> + {t('MCP Tool')} + + }> + {p.name} + + {running ? : null} + + + + +
+ {summary ? ( +
+ {summary} +
+ ) : null} + {expanded && ( +
+ {t('Arguments')} +
+                                  {JSON.stringify(args, null, 2)}
+                                
+
+ )} +
+ ); + })} +
+
+ )} +
+ )} +
+ +
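+            {/* Composer: shows the current file-manager path; Enter sends,
+                Shift+Enter inserts a newline. */}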
+
+
+
+              {effectivePath && (
+                {t('Current')}: {effectivePath}
+              )}
+
+
+               setInput(e.target.value)}
+                placeholder={t('Type a message')}
+                autoSize={{ minRows: 2, maxRows: 6 }}
+                disabled={loading || pending.length > 0}
+                variant="borderless"
+                onPressEnter={(e) => {
+                  if (e.shiftKey) return;
+                  e.preventDefault();
+                  void handleSend();
+                }}
+              />
+ +
+
+
+
+
+
+       setPathModalOpen(false)}
+      />
+
+    );
+});
+
+export default AiAgentWidget;
diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json
index 100595f..541d0ce 100644
--- a/web/src/i18n/locales/en.json
+++ b/web/src/i18n/locales/en.json
@@ -690,5 +690,40 @@
     "App \"{key}\" not found.": "App \"{key}\" not found.",
     "Open with {app}": "Open with {app}",
     "Set as default for .{ext}": "Set as default for .{ext}",
-    "Advanced tokens must be valid JSON": "Advanced tokens must be valid JSON"
+    "AI Agent": "AI Agent",
+    "Auto execute": "Auto execute",
+    "Start a conversation": "Start a conversation",
+    "No content": "No content",
+    "Pending actions": "Pending actions",
+    "Execute": "Execute",
+    "Execute all": "Execute all",
+    "Cancel all": "Cancel all",
+    "Type a message": "Type a message",
+    "Send": "Send",
+    "Please confirm pending actions first": "Please confirm pending actions first",
+    "You": "You",
+    "Tool": "Tool",
+    "MCP Tool": "MCP Tool",
+    "Arguments": "Arguments",
+    "Raw JSON": "Raw JSON",
+    "Collapse": "Collapse",
+    "Copied": "Copied",
+    "Canceled": "Canceled",
+    "Tasks submitted": "Tasks submitted",
+    "Calling tools": "Calling tools",
+    "Advanced tokens must be valid JSON": "Advanced tokens must be valid JSON",
+    "Search": "Search",
+    "Total": "Total",
+    "Mode": "Mode",
+    "Has more": "Has more",
+    "Page": "Page",
+    "results": "results",
+    "chars": "chars",
+    "Truncated": "Truncated",
+    "Write": "Write",
+    "Read": "Read",
+    "Created": "Created",
+    "Moved": "Moved",
+    "Renamed": "Renamed",
+    "Info": "Info"
 }
diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json
index 14cb2b1..40c5911 100644
--- a/web/src/i18n/locales/zh.json
+++ b/web/src/i18n/locales/zh.json
@@ -683,5 +683,40 @@
     "App \"{key}\" not found.": "应用 \"{key}\" 不存在。",
     "Open with {app}": "使用 {app} 打开",
     "Set as default for .{ext}": "设为该类型(.{ext})默认应用",
-    "Advanced tokens must be valid JSON": "高级 Token 需为合法 JSON"
+    "AI Agent": "AI 助手",
+    "Auto execute": "自动执行",
+    "Start a conversation": "开始对话",
+    "No content": "无内容",
+    "Pending actions": "待确认操作",
+    "Execute": "执行",
+    "Execute all": "全部执行",
+    "Cancel all": "全部取消",
+    "Type a message": "输入消息",
+    "Send": "发送",
+    "Please confirm pending actions first": "请先确认待执行操作",
+    "You": "你",
+    "Tool": "工具",
+    "MCP Tool": "MCP 工具",
+    "Arguments": "参数",
+    "Raw JSON": "原始 JSON",
+    "Collapse": "收起",
+    "Copied": "已复制",
+    "Canceled": "已取消",
+    "Tasks submitted": "已提交任务",
+    "Calling tools": "正在调用工具",
+    "Advanced tokens must be valid JSON": "高级 Token 需为合法 JSON",
+    "Search": "搜索",
+    "Total": "总计",
+    "Mode": "模式",
+    "Has more": "更多",
+    "Page": "页",
+    "results": "条结果",
+    "chars": "字符",
+    "Truncated": "已截断",
+    "Write": "写入",
+    "Read": "读取",
+    "Created": "已创建",
+    "Moved": "已移动",
+    "Renamed": "已重命名",
+    "Info": "信息"
 }
diff --git a/web/src/layout/TopHeader.tsx b/web/src/layout/TopHeader.tsx
index a10a8c1..0b24128 100644
--- a/web/src/layout/TopHeader.tsx
+++ b/web/src/layout/TopHeader.tsx
@@ -1,5 +1,5 @@
-import { Layout, Button, Dropdown, theme, Flex, Avatar, Typography } from 'antd';
-import { SearchOutlined, MenuUnfoldOutlined, LogoutOutlined, UserOutlined } from '@ant-design/icons';
+import { Layout, Button, Dropdown, theme, Flex, Avatar, Typography, Tooltip } from 'antd';
+import { SearchOutlined, MenuUnfoldOutlined, LogoutOutlined, UserOutlined, RobotOutlined } from '@ant-design/icons';
 import { memo, useState } from 'react';
 import SearchDialog from './SearchDialog.tsx';
 import { authApi } from '../api/auth.ts';
@@ -14,9 +14,10 @@ const { Header } = Layout;
 export interface TopHeaderProps {
   collapsed: boolean;
   onToggle(): void;
+  onOpenAiAgent(): void;
 }
 
-const TopHeader = memo(function TopHeader({ collapsed, onToggle }: TopHeaderProps) {
+const TopHeader = memo(function TopHeader({ collapsed, onToggle, onOpenAiAgent }: TopHeaderProps) {
   const { token } = theme.useToken();
   const [searchOpen, setSearchOpen] = useState(false);
   const navigate = useNavigate();
@@ -50,6 +51,15 @@ const TopHeader = memo(function TopHeader({ collapsed, onToggle }: TopHeaderProp
       setSearchOpen(false)} />
+
+
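
The TopHeader hunk above is truncated before the new toolbar button, and the LayoutShell.tsx change listed in the diffstat is not shown. For orientation only, a minimal sketch (not the patch's actual code) of how the new `onOpenAiAgent` prop is presumably wired: the `open`/`onOpenChange` prop names come from AiAgentWidget's Drawer usage in this patch, while the local state name and the pass-through of `collapsed`/`onToggle` are assumptions.

    // Sketch only: LayoutShell owns the drawer open state and hands the
    // opener to the header button while rendering the widget alongside it.
    import { useState } from 'react';
    import TopHeader from '../layout/TopHeader';
    import AiAgentWidget from '../components/AiAgentWidget';

    export function LayoutShellSketch(props: { collapsed: boolean; onToggle(): void }) {
      const [aiAgentOpen, setAiAgentOpen] = useState(false); // assumed state name
      return (
        <>
          <TopHeader
            collapsed={props.collapsed}
            onToggle={props.onToggle}
            onOpenAiAgent={() => setAiAgentOpen(true)}
          />
          <AiAgentWidget open={aiAgentOpen} onOpenChange={setAiAgentOpen} />
        </>
      );
    }

Keeping the open state in the shell lets the header button and the drawer stay in sync without a global store.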