mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-16 10:57:36 +08:00
feat: Implement AI Agent with enhanced tool processing capabilities (#89)
* feat: Implement AI Agent with tool processing capabilities - Added tools for listing and running processors in the agent. - Created data models for agent chat requests and tool calls. - Developed API integration for agent chat and streaming responses. - Built the AI Agent widget with a user interface for interaction. - Styled the agent components for better user experience. * feat: 增强 AI 助手工具功能,添加文件操作和搜索功能,更新界面显示 * feat: 更新 AI 助手组件 * feat: 更新 AiAgentWidget 组件样式,调整背景和边距以提升界面一致性
This commit is contained in:
4
domain/agent/__init__.py
Normal file
4
domain/agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .api import router
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
39
domain/agent/api.py
Normal file
39
domain/agent/api.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.response import success
|
||||
from domain.agent.service import AgentService
|
||||
from domain.agent.types import AgentChatRequest
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await AgentService.chat(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"])
|
||||
async def chat_stream(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return StreamingResponse(
|
||||
AgentService.chat_stream(payload, current_user),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
448
domain/agent/service.py
Normal file
448
domain/agent/service.py
Normal file
@@ -0,0 +1,448 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.agent.tools import get_tool, openai_tools, tool_result_to_content
|
||||
from domain.agent.types import AgentChatRequest, PendingToolCall
|
||||
from domain.ai.inference import MissingModelError, chat_completion, chat_completion_stream
|
||||
from domain.ai.service import AIProviderService
|
||||
from domain.auth.types import User
|
||||
|
||||
|
||||
def _normalize_path(p: Optional[str]) -> Optional[str]:
|
||||
if not p:
|
||||
return None
|
||||
s = str(p).strip()
|
||||
if not s:
|
||||
return None
|
||||
s = s.replace("\\", "/")
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _build_system_prompt(current_path: Optional[str]) -> str:
|
||||
lines = [
|
||||
"你是 Foxel 的 AI 助手。",
|
||||
"你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
|
||||
"",
|
||||
"可用工具:",
|
||||
"- vfs_list_dir:浏览目录(列出 entries + pagination)。",
|
||||
"- vfs_stat:查看文件/目录信息。",
|
||||
"- vfs_read_text:读取文本文件内容(不支持二进制)。",
|
||||
"- vfs_search:搜索文件(vector/filename)。",
|
||||
"- vfs_write_text:写入文本文件内容(覆盖)。",
|
||||
"- vfs_mkdir:创建目录。",
|
||||
"- vfs_delete:删除文件或目录。",
|
||||
"- vfs_move:移动路径。",
|
||||
"- vfs_copy:复制路径。",
|
||||
"- vfs_rename:重命名路径。",
|
||||
"- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
|
||||
"- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
|
||||
"",
|
||||
"规则:",
|
||||
"1) 读操作(vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
|
||||
"2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
|
||||
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
|
||||
"4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
|
||||
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
|
||||
"6) 回答保持简洁中文。",
|
||||
]
|
||||
if current_path:
|
||||
lines.append("")
|
||||
lines.append(f"当前文件管理目录:{current_path}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list):
|
||||
return message
|
||||
|
||||
changed = False
|
||||
for idx, call in enumerate(tool_calls):
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = call.get("id")
|
||||
if isinstance(call_id, str) and call_id.strip():
|
||||
continue
|
||||
call["id"] = f"call_{idx}"
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
message["tool_calls"] = tool_calls
|
||||
return message
|
||||
|
||||
|
||||
def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
    """Convert a raw OpenAI-style tool call dict into a PendingToolCall.

    The "arguments" field arrives as a JSON string; malformed JSON or a
    non-object payload degrades to an empty arguments dict instead of raising.
    """
    fn_raw = tool_call.get("function")
    fn = fn_raw if isinstance(fn_raw, dict) else {}

    parsed_args: Dict[str, Any] = {}
    serialized = fn.get("arguments")
    if isinstance(serialized, str) and serialized.strip():
        try:
            candidate = json.loads(serialized)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            parsed_args = candidate

    return PendingToolCall(
        id=str(tool_call.get("id") or ""),
        name=str(fn.get("name") or ""),
        arguments=parsed_args,
        requires_confirmation=requires_confirmation,
    )
|
||||
|
||||
|
||||
def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
    """Locate the most recent assistant message that carries tool_calls.

    Scans from the end of the history and returns (index, message).

    Raises:
        HTTPException: 400 when no assistant message has a non-empty
            tool_calls list (nothing to confirm or reject).
    """
    for position in reversed(range(len(messages))):
        candidate = messages[position]
        if not isinstance(candidate, dict) or candidate.get("role") != "assistant":
            continue
        calls = candidate.get("tool_calls")
        if isinstance(calls, list) and calls:
            return position, candidate
    raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
|
||||
|
||||
|
||||
def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id.strip():
|
||||
ids.add(tool_call_id)
|
||||
return ids
|
||||
|
||||
|
||||
async def _choose_chat_ability() -> str:
    """Select the inference ability: "tools" when a default tools model is
    configured, otherwise fall back to plain "chat"."""
    if await AIProviderService.get_default_model("tools"):
        return "tools"
    return "chat"
|
||||
|
||||
|
||||
def _sse(event: str, data: Any) -> bytes:
|
||||
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
|
||||
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
class AgentService:
    """Agent orchestration: drives the LLM tool-calling loop for one request.

    Both entry points run the same three phases:
      1. Settle tool calls the user just approved or rejected (executing or
         recording a cancellation result for each).
      2. Model/tool loop (hard-capped at 4 rounds): ask the model, execute
         any tool calls that do not need confirmation, stop when the model
         answers without tools or when a call needs user approval.
      3. Return only the messages produced during THIS turn, plus any
         pending (confirmation-required) tool calls.

    ``chat`` returns a dict; ``chat_stream`` is an async generator yielding
    SSE frames with the same information spread over events.
    """

    @classmethod
    async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
        """Run one non-streaming agent turn.

        Returns ``{"messages": [...]}`` and, when a tool call awaits user
        confirmation, a ``"pending_tool_calls"`` list. NOTE(review): ``user``
        is not used in this body — tools run without per-user scoping here;
        authorization happens at the API layer. TODO confirm that is intended.
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)

        # System prompt is injected server-side; the client never supplies it.
        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history

        # Only messages created during this call are returned to the client.
        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingToolCall] = []

        approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}

        # Phase 1: settle the tool calls the user just confirmed or rejected.
        if approved_ids or rejected_ids:
            _, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
            last_call_msg = _ensure_tool_call_ids(last_call_msg)
            tool_calls = last_call_msg.get("tool_calls") or []
            call_map: Dict[str, Dict[str, Any]] = {
                str(c.get("id")): c
                for c in tool_calls
                if isinstance(c, dict) and isinstance(c.get("id"), str)
            }

            existing_ids = _existing_tool_result_ids(internal_messages)
            for call_id in approved_ids | rejected_ids:
                if call_id in existing_ids:
                    # Already answered on a previous round — do not run twice.
                    continue
                tool_call = call_map.get(call_id)
                if not tool_call:
                    # Id does not match any call on the last assistant message.
                    continue
                fn = tool_call.get("function") or {}
                name = fn.get("name") if isinstance(fn, dict) else None
                args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                args: Dict[str, Any] = {}
                # Arguments arrive as a JSON string; bad JSON degrades to {}.
                if isinstance(args_raw, str) and args_raw.strip():
                    try:
                        parsed = json.loads(args_raw)
                        if isinstance(parsed, dict):
                            args = parsed
                    except json.JSONDecodeError:
                        args = {}

                spec = get_tool(str(name or ""))
                if call_id in rejected_ids:
                    # Record the rejection as a tool result so the model sees it.
                    content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                if not spec:
                    content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                # Tool failures are reported back to the model, never raised.
                try:
                    result = await spec.handler(args)
                    content = tool_result_to_content(result)
                except Exception as exc:  # noqa: BLE001
                    content = tool_result_to_content({"error": str(exc)})
                tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                internal_messages.append(tool_msg)
                new_messages.append(tool_msg)

        tools_schema = openai_tools()
        ability = await _choose_chat_ability()
        max_loops = 4  # hard cap on model<->tool rounds per request

        # Phase 2: model/tool loop.
        for _ in range(max_loops):
            try:
                assistant = await chat_completion(
                    internal_messages,
                    ability=ability,
                    tools=tools_schema,
                    tool_choice="auto",
                    timeout=60.0,
                )
            except MissingModelError as exc:
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            except httpx.HTTPStatusError as exc:
                raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
            except httpx.RequestError as exc:
                raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc

            assistant = _ensure_tool_call_ids(assistant)
            internal_messages.append(assistant)
            new_messages.append(assistant)

            tool_calls = assistant.get("tool_calls")
            if not isinstance(tool_calls, list) or not tool_calls:
                # Plain answer with no tool calls — the turn is complete.
                break

            pending = []
            for call in tool_calls:
                if not isinstance(call, dict):
                    continue
                call_id = str(call.get("id") or "")
                fn = call.get("function") or {}
                name = fn.get("name") if isinstance(fn, dict) else None
                args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                args: Dict[str, Any] = {}
                if isinstance(args_raw, str) and args_raw.strip():
                    try:
                        parsed = json.loads(args_raw)
                        if isinstance(parsed, dict):
                            args = parsed
                    except json.JSONDecodeError:
                        args = {}

                spec = get_tool(str(name or ""))
                if not spec:
                    content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                # Mutating tools wait for user approval unless auto_execute is on.
                if spec.requires_confirmation and not req.auto_execute:
                    pending.append(_extract_pending(call, True))
                    continue

                try:
                    result = await spec.handler(args)
                    content = tool_result_to_content(result)
                except Exception as exc:  # noqa: BLE001
                    content = tool_result_to_content({"error": str(exc)})
                tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                internal_messages.append(tool_msg)
                new_messages.append(tool_msg)

            if pending:
                # Stop looping until the user confirms or rejects.
                break

        payload: Dict[str, Any] = {"messages": new_messages}
        if pending:
            payload["pending_tool_calls"] = [p.model_dump() for p in pending]
        return payload

    @classmethod
    async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
        """Streaming variant of chat(); yields pre-encoded SSE frames (bytes).

        Events: tool_start / tool_end, assistant_start / assistant_delta /
        assistant_end, pending, done. Mirrors chat()'s three phases; ``user``
        is likewise unused in this body.
        """
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)

        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history

        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingToolCall] = []

        approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}

        try:
            # Phase 1: settle approved/rejected calls, emitting tool events.
            if approved_ids or rejected_ids:
                _, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
                last_call_msg = _ensure_tool_call_ids(last_call_msg)
                tool_calls = last_call_msg.get("tool_calls") or []
                call_map: Dict[str, Dict[str, Any]] = {
                    str(c.get("id")): c
                    for c in tool_calls
                    if isinstance(c, dict) and isinstance(c.get("id"), str)
                }

                existing_ids = _existing_tool_result_ids(internal_messages)
                for call_id in approved_ids | rejected_ids:
                    if call_id in existing_ids:
                        continue
                    tool_call = call_map.get(call_id)
                    if not tool_call:
                        continue
                    fn = tool_call.get("function") or {}
                    name = fn.get("name") if isinstance(fn, dict) else None
                    args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                    args: Dict[str, Any] = {}
                    if isinstance(args_raw, str) and args_raw.strip():
                        try:
                            parsed = json.loads(args_raw)
                            if isinstance(parsed, dict):
                                args = parsed
                        except json.JSONDecodeError:
                            args = {}

                    spec = get_tool(str(name or ""))
                    if call_id in rejected_ids:
                        content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    if not spec:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
                    try:
                        result = await spec.handler(args)
                        content = tool_result_to_content(result)
                    except Exception as exc:  # noqa: BLE001
                        content = tool_result_to_content({"error": str(exc)})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})

            tools_schema = openai_tools()
            ability = await _choose_chat_ability()
            max_loops = 4  # hard cap on model<->tool rounds per request

            # Phase 2: streamed model/tool loop.
            for _ in range(max_loops):
                assistant_event_id = uuid.uuid4().hex
                yield _sse("assistant_start", {"id": assistant_event_id})

                assistant_message: Dict[str, Any] | None = None
                try:
                    async for event in chat_completion_stream(
                        internal_messages,
                        ability=ability,
                        tools=tools_schema,
                        tool_choice="auto",
                        timeout=60.0,
                    ):
                        if event.get("type") == "delta":
                            delta = event.get("delta")
                            if isinstance(delta, str) and delta:
                                yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
                        elif event.get("type") == "message":
                            # Final consolidated assistant message for this round.
                            msg = event.get("message")
                            if isinstance(msg, dict):
                                assistant_message = msg
                except MissingModelError as exc:
                    # NOTE(review): raising inside the generator after frames
                    # were already sent cannot change the HTTP status — the
                    # client sees an aborted SSE body. TODO confirm handling.
                    raise HTTPException(status_code=400, detail=str(exc)) from exc
                except httpx.HTTPStatusError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
                except httpx.RequestError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc

                if not assistant_message:
                    assistant_message = {"role": "assistant", "content": ""}

                assistant_message = _ensure_tool_call_ids(assistant_message)
                internal_messages.append(assistant_message)
                new_messages.append(assistant_message)
                yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})

                tool_calls = assistant_message.get("tool_calls")
                if not isinstance(tool_calls, list) or not tool_calls:
                    break

                pending = []
                for call in tool_calls:
                    if not isinstance(call, dict):
                        continue
                    call_id = str(call.get("id") or "")
                    fn = call.get("function") or {}
                    name = fn.get("name") if isinstance(fn, dict) else None
                    args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                    args: Dict[str, Any] = {}
                    if isinstance(args_raw, str) and args_raw.strip():
                        try:
                            parsed = json.loads(args_raw)
                            if isinstance(parsed, dict):
                                args = parsed
                        except json.JSONDecodeError:
                            args = {}

                    spec = get_tool(str(name or ""))
                    if not spec:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    if spec.requires_confirmation and not req.auto_execute:
                        pending.append(_extract_pending(call, True))
                        continue

                    yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
                    try:
                        result = await spec.handler(args)
                        content = tool_result_to_content(result)
                    except Exception as exc:  # noqa: BLE001
                        content = tool_result_to_content({"error": str(exc)})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})

                if pending:
                    yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
                    break

            payload: Dict[str, Any] = {"messages": new_messages}
            if pending:
                payload["pending_tool_calls"] = [p.model_dump() for p in pending]
            yield _sse("done", payload)

        except asyncio.CancelledError:
            # Client disconnected mid-stream — stop quietly.
            return
|
||||
413
domain/agent/tools.py
Normal file
413
domain/agent/tools.py
Normal file
@@ -0,0 +1,413 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from domain.processors.service import ProcessorService
|
||||
from domain.processors.types import ProcessDirectoryRequest, ProcessRequest
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs.search.search_service import VirtualFSSearchService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolSpec:
    """Immutable description of one agent tool in the TOOLS registry."""

    # Tool name exposed to the model (also the TOOLS registry key).
    name: str
    # Description shown to the model in the function schema (Chinese).
    description: str
    # JSON-Schema "parameters" object, OpenAI function-calling format.
    parameters: Dict[str, Any]
    # True for mutating tools that need user approval unless auto_execute.
    requires_confirmation: bool
    # Async callable executing the tool with already-parsed arguments.
    handler: Callable[[Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
    """Return metadata for every registered processor (ignores its arguments)."""
    available = ProcessorService.list_processors()
    return {"processors": available}
|
||||
|
||||
|
||||
async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
    """Submit a processor job for a file or a directory.

    Dispatch: directory mode is used only when the path IS a directory AND
    the caller supplied max_depth or suffix; a bare directory path without
    those hints still goes down the single-file route — presumably
    ProcessorService.process_file rejects or handles it; TODO confirm.
    Returns the service result merged with a "mode" marker.
    """
    path = str(args.get("path") or "")
    processor_type = str(args.get("processor_type") or "")
    config = args.get("config")
    if not isinstance(config, dict):
        config = {}

    # save_to only applies to file mode; blank strings count as absent.
    save_to = args.get("save_to")
    save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None

    # Coerce max_depth to int; anything unparsable is treated as absent.
    max_depth = args.get("max_depth")
    max_depth_value: Optional[int] = None
    if max_depth is not None:
        try:
            max_depth_value = int(max_depth)
        except (TypeError, ValueError):
            max_depth_value = None

    suffix = args.get("suffix")
    suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None

    # Tri-state overwrite: None means "use the per-mode default" below.
    overwrite_value = args.get("overwrite")
    overwrite = bool(overwrite_value) if overwrite_value is not None else None

    is_dir = await VirtualFSService.path_is_directory(path)
    if is_dir and (max_depth_value is not None or suffix_value is not None):
        req = ProcessDirectoryRequest(
            path=path,
            processor_type=processor_type,
            config=config,
            # Directory batches default to overwriting files in place.
            overwrite=True if overwrite is None else overwrite,
            max_depth=max_depth_value,
            suffix=suffix_value,
        )
        result = await ProcessorService.process_directory(req)
        return {"mode": "directory", **result}

    req = ProcessRequest(
        path=path,
        processor_type=processor_type,
        config=config,
        save_to=save_to,
        # Single files default to NOT overwriting (output goes to save_to).
        overwrite=False if overwrite is None else overwrite,
    )
    result = await ProcessorService.process_file(req)
    return {"mode": "file", **result}
|
||||
|
||||
|
||||
def _normalize_vfs_path(value: Any) -> str:
|
||||
s = str(value or "").strip().replace("\\", "/")
|
||||
if not s:
|
||||
return ""
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _require_vfs_path(value: Any, field: str) -> str:
|
||||
path = _normalize_vfs_path(value)
|
||||
if not path:
|
||||
raise ValueError(f"missing_{field}")
|
||||
return path
|
||||
|
||||
|
||||
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
    """List a directory with pagination and sorting; missing path means root."""
    target = _normalize_vfs_path(args.get("path") or "/") or "/"
    return await VirtualFSService.list_directory(
        target,
        int(args.get("page") or 1),
        int(args.get("page_size") or 50),
        str(args.get("sort_by") or "name"),
        str(args.get("sort_order") or "asc"),
    )
|
||||
|
||||
|
||||
async def _vfs_stat(args: Dict[str, Any]) -> Any:
    """Stat a file or directory; the path argument is required."""
    target = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.stat(target)
|
||||
|
||||
|
||||
async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
    """Read a file as text, truncating the returned content to max_chars.

    Bytes that cannot be decoded with the requested encoding are treated as
    binary and reported as an error payload instead of raising.
    """
    path = _require_vfs_path(args.get("path"), "path")
    encoding = str(args.get("encoding") or "utf-8")
    limit = int(args.get("max_chars") or 8000)

    raw = await VirtualFSService.read_file(path)
    if isinstance(raw, str):
        text = raw
    elif isinstance(raw, (bytes, bytearray)):
        try:
            text = bytes(raw).decode(encoding)
        except UnicodeDecodeError:
            return {"error": "binary_or_invalid_text", "path": path}
    else:
        # Unexpected payload type — fall back to its string form.
        text = str(raw)

    total = len(text)
    clipped = total > limit
    return {
        "path": path,
        "encoding": encoding,
        "content": text[:limit] if clipped else text,
        "truncated": clipped,
        "length": total,  # length BEFORE truncation
    }
|
||||
|
||||
|
||||
async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
    """Overwrite a file with text content; writing to the root is rejected."""
    path = _require_vfs_path(args.get("path"), "path")
    if path == "/":
        raise ValueError("invalid_path")
    encoding = str(args.get("encoding") or "utf-8")
    payload = str(args.get("content") or "").encode(encoding)
    await VirtualFSService.write_file(path, payload)
    return {"written": True, "path": path, "encoding": encoding, "bytes": len(payload)}
|
||||
|
||||
|
||||
async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
    """Create a directory at the required absolute path."""
    target = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.mkdir(target)
|
||||
|
||||
|
||||
async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
    """Delete a file or directory at the required absolute path."""
    target = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.delete(target)
|
||||
|
||||
|
||||
async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
    """Move src to dst; moving the root (either side) is rejected."""
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if "/" in (src, dst):
        raise ValueError("invalid_path")
    allow_overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.move(src, dst, allow_overwrite)
|
||||
|
||||
|
||||
async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
    """Copy src to dst; copying the root (either side) is rejected."""
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if "/" in (src, dst):
        raise ValueError("invalid_path")
    allow_overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.copy(src, dst, allow_overwrite)
|
||||
|
||||
|
||||
async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
    """Rename src to dst (same-directory move); root paths are rejected."""
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if "/" in (src, dst):
        raise ValueError("invalid_path")
    allow_overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.rename(src, dst, allow_overwrite)
|
||||
|
||||
|
||||
async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
    """Search files by vector similarity or filename match."""
    query = str(args.get("q") or "").strip()
    if not query:
        raise ValueError("missing_q")
    return await VirtualFSSearchService.search(
        query,
        int(args.get("top_k") or 10),
        str(args.get("mode") or "vector"),
        int(args.get("page") or 1),
        int(args.get("page_size") or 10),
    )
|
||||
|
||||
|
||||
# Registry of every tool the agent may call, keyed by tool name.
# Read-only tools have requires_confirmation=False; mutating tools
# (write/mkdir/delete/move/copy/rename/processors_run) require user
# confirmation unless the request sets auto_execute.
TOOLS: Dict[str, ToolSpec] = {
    # --- Processor tools ---
    "processors_list": ToolSpec(
        name="processors_list",
        description="获取可用处理器列表(type/name/config_schema 等)。",
        parameters={
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_processors_list,
    ),
    "processors_run": ToolSpec(
        name="processors_run",
        description=(
            "运行处理器处理文件或目录。"
            " 对目录可选 max_depth/suffix;对文件可选 overwrite/save_to。"
            " 返回任务 id(去任务队列查看进度)。"
        ),
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar)"},
                "processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark)"},
                "config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"},
                "overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"},
                "save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"},
                "max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"},
                "suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false)"},
            },
            "required": ["path", "processor_type"],
        },
        requires_confirmation=True,
        handler=_processors_run,
    ),
    # --- Read-only VFS tools ---
    "vfs_list_dir": ToolSpec(
        name="vfs_list_dir",
        description="浏览目录(列出 entries + pagination)。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
                "page": {"type": "integer", "description": "页码(从 1 开始)"},
                "page_size": {"type": "integer", "description": "每页条数"},
                "sort_by": {"type": "string", "description": "排序字段:name/size/mtime"},
                "sort_order": {"type": "string", "description": "排序顺序:asc/desc"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_list_dir,
    ),
    "vfs_stat": ToolSpec(
        name="vfs_stat",
        description="查看文件/目录信息(size/mtime/is_dir/has_thumbnail/vector_index 等)。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_stat,
    ),
    "vfs_read_text": ToolSpec(
        name="vfs_read_text",
        description="读取文本文件内容(解码失败视为二进制,返回 error)。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
                "encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
                "max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_read_text,
    ),
    # --- Mutating VFS tools (confirmation required) ---
    "vfs_write_text": ToolSpec(
        name="vfs_write_text",
        description="写入文本文件内容(会覆盖目标文件)。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
                "content": {"type": "string", "description": "要写入的文本内容"},
                "encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
            },
            "required": ["path", "content"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_write_text,
    ),
    "vfs_mkdir": ToolSpec(
        name="vfs_mkdir",
        description="创建目录。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_mkdir,
    ),
    "vfs_delete": ToolSpec(
        name="vfs_delete",
        description="删除文件或目录(由底层适配器决定是否递归)。",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_delete,
    ),
    "vfs_move": ToolSpec(
        name="vfs_move",
        description="移动路径(可能进入任务队列)。",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "源路径(绝对路径)"},
                "dst": {"type": "string", "description": "目标路径(绝对路径)"},
                "overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_move,
    ),
    "vfs_copy": ToolSpec(
        name="vfs_copy",
        description="复制路径(可能进入任务队列)。",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "源路径(绝对路径)"},
                "dst": {"type": "string", "description": "目标路径(绝对路径)"},
                "overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_copy,
    ),
    "vfs_rename": ToolSpec(
        name="vfs_rename",
        description="重命名路径(本质是同目录 move)。",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "源路径(绝对路径)"},
                "dst": {"type": "string", "description": "目标路径(绝对路径)"},
                "overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_rename,
    ),
    "vfs_search": ToolSpec(
        name="vfs_search",
        description="搜索文件(mode=vector 或 filename)。",
        parameters={
            "type": "object",
            "properties": {
                "q": {"type": "string", "description": "搜索关键词"},
                "mode": {"type": "string", "description": "搜索模式:vector/filename(默认 vector)"},
                "top_k": {"type": "integer", "description": "返回数量(vector 模式使用,默认 10)"},
                "page": {"type": "integer", "description": "页码(filename 模式使用,默认 1)"},
                "page_size": {"type": "integer", "description": "分页大小(filename 模式使用,默认 10)"},
            },
            "required": ["q"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_search,
    ),
}
|
||||
|
||||
|
||||
def get_tool(name: str) -> Optional[ToolSpec]:
    """Look up a ToolSpec by name, returning None for unknown tools."""
    try:
        return TOOLS[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def openai_tools() -> List[Dict[str, Any]]:
    """Render the tool registry as OpenAI function-calling schemas."""
    return [
        {
            "type": "function",
            "function": {
                "name": spec.name,
                "description": spec.description,
                "parameters": spec.parameters,
            },
        }
        for spec in TOOLS.values()
    ]
|
||||
|
||||
|
||||
def tool_result_to_content(result: Any) -> str:
    """Serialize a tool handler's result into tool-message string content.

    Strings pass through untouched and None becomes ""; everything else is
    JSON-encoded. Results json.dumps cannot handle fall back to their str()
    form so a tool result never crashes the agent loop.
    """
    if result is None:
        return ""
    if isinstance(result, str):
        return result
    try:
        return json.dumps(result, ensure_ascii=False)
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unsupported types but ValueError
        # for circular references; the original caught only TypeError, so a
        # self-referential result would have propagated and killed the turn.
        return json.dumps({"result": str(result)}, ensure_ascii=False)
|
||||
23
domain/agent/types.py
Normal file
23
domain/agent/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentChatContext(BaseModel):
    """Client-side context sent along with an agent chat request."""

    # Directory currently open in the file manager; used server-side to
    # resolve relative path descriptions into absolute paths.
    current_path: Optional[str] = None
|
||||
|
||||
|
||||
class AgentChatRequest(BaseModel):
    """Request body for /api/agent/chat and /api/agent/chat/stream."""

    # Conversation so far: OpenAI-style chat message dicts
    # (role/content/tool_calls/tool_call_id...).
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    # When True, confirmation-requiring tools execute without user approval.
    auto_execute: bool = False
    # Tool-call ids the user approved from the previous turn's pending list.
    approved_tool_call_ids: List[str] = Field(default_factory=list)
    # Tool-call ids the user rejected from the previous turn's pending list.
    rejected_tool_call_ids: List[str] = Field(default_factory=list)
    # Optional UI context (e.g. the current file-manager directory).
    context: Optional[AgentChatContext] = None
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
    """A tool call awaiting user approval, returned to the client."""

    # Id of the originating tool call (matches the assistant message).
    id: str
    # Tool name (a key of the TOOLS registry).
    name: str
    # Parsed JSON arguments; empty when the arguments were malformed.
    arguments: Dict[str, Any] = Field(default_factory=dict)
    # Echoed flag; True for all calls surfaced through this model.
    requires_confirmation: bool = True
|
||||
|
||||
Reference in New Issue
Block a user