From 6b9790026c4aacee47106f297a2b9915513624af Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 08:50:48 +0800 Subject: [PATCH] refine plugin agent tool responsibilities --- app/agent/tools/factory.py | 6 + app/agent/tools/impl/_plugin_tool_utils.py | 217 ++++++++++++++++++ app/agent/tools/impl/install_plugin.py | 118 ++++++++++ .../tools/impl/query_installed_plugins.py | 99 +++++--- app/agent/tools/impl/query_market_plugins.py | 113 +++++++++ app/agent/tools/impl/uninstall_plugin.py | 84 +++++++ tests/test_agent_plugin_tools.py | 100 ++++++++ 7 files changed, 709 insertions(+), 28 deletions(-) create mode 100644 app/agent/tools/impl/install_plugin.py create mode 100644 app/agent/tools/impl/query_market_plugins.py create mode 100644 app/agent/tools/impl/uninstall_plugin.py diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 86a78ec9..9d992031 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -52,11 +52,14 @@ from app.agent.tools.impl.write_file import WriteFileTool from app.agent.tools.impl.read_file import ReadFileTool from app.agent.tools.impl.browse_webpage import BrowseWebpageTool from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool +from app.agent.tools.impl.query_market_plugins import QueryMarketPluginsTool from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool from app.agent.tools.impl.query_plugin_config import QueryPluginConfigTool from app.agent.tools.impl.update_plugin_config import UpdatePluginConfigTool from app.agent.tools.impl.reload_plugin import ReloadPluginTool from app.agent.tools.impl.query_plugin_data import QueryPluginDataTool +from app.agent.tools.impl.install_plugin import InstallPluginTool +from app.agent.tools.impl.uninstall_plugin import UninstallPluginTool from app.agent.tools.impl.run_slash_command import RunSlashCommandTool from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool @@ -149,11 +152,14 @@ class MoviePilotToolFactory: ReadFileTool, BrowseWebpageTool, QueryInstalledPluginsTool, + QueryMarketPluginsTool, QueryPluginCapabilitiesTool, QueryPluginConfigTool, UpdatePluginConfigTool, ReloadPluginTool, QueryPluginDataTool, + InstallPluginTool, + UninstallPluginTool, RunSlashCommandTool, ListSlashCommandsTool, QueryCustomIdentifiersTool, diff --git a/app/agent/tools/impl/_plugin_tool_utils.py b/app/agent/tools/impl/_plugin_tool_utils.py index 9b1138c3..ba25a215 100644 --- a/app/agent/tools/impl/_plugin_tool_utils.py +++ b/app/agent/tools/impl/_plugin_tool_utils.py @@ -1,15 +1,21 @@ """插件 Agent 工具共享辅助方法""" import json +import shutil from typing import Any, Optional +from app.core.config import settings from app.core.plugin import PluginManager +from app.db.systemconfig_oper import SystemConfigOper +from app.helper.plugin import PluginHelper +from app.schemas.types import SystemConfigKey # 默认只向智能体返回一个可读预览,避免超大插件数据挤爆上下文窗口。 DEFAULT_PLUGIN_DATA_PREVIEW_CHARS = 12_000 MAX_PLUGIN_DATA_PREVIEW_CHARS = 50_000 PLUGIN_DATA_KEY_PREVIEW_LIMIT = 50 PLUGIN_DATA_TRUNCATION_SUFFIX = "\n...(插件数据内容过长,已截断)" +DEFAULT_PLUGIN_CANDIDATE_LIMIT = 10 def get_plugin_snapshot(plugin_id: str) -> Optional[dict[str, Any]]: @@ -71,3 +77,214 @@ def reload_plugin_runtime(plugin_id: str) -> None: Scheduler().update_plugin_job(plugin_id) Command().init_commands(plugin_id) register_plugin_api(plugin_id) + + +def summarize_plugin(plugin: Any) -> dict[str, Any]: + """ + 提取插件对象中对 Agent 有价值的摘要字段。 + """ + repo_url = getattr(plugin, "repo_url", None) + return { + "id": getattr(plugin, "id", None), + "plugin_name": getattr(plugin, "plugin_name", None), + "plugin_desc": getattr(plugin, "plugin_desc", None), + "plugin_version": getattr(plugin, "plugin_version", None), + "plugin_author": getattr(plugin, "plugin_author", None), + "installed": bool(getattr(plugin, "installed", False)), + "has_update": bool(getattr(plugin, "has_update", False)), + "state": bool(getattr(plugin, "state", False)), + "repo_url": repo_url, + "source": "local_repo" if PluginHelper.is_local_repo_url(repo_url) else "market", + } + + +async def load_market_plugins(force_refresh: bool = False) -> list[Any]: + """ + 聚合插件市场与本地插件仓库中的候选插件。 + """ + plugin_manager = PluginManager() + online_plugins = await plugin_manager.async_get_online_plugins(force=force_refresh) + local_repo_plugins = plugin_manager.get_local_repo_plugins() + if not online_plugins and not local_repo_plugins: + return [] + return plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, []) + + +def list_installed_plugins() -> list[Any]: + """ + 返回当前已安装插件列表。 + """ + plugin_manager = PluginManager() + return [plugin for plugin in plugin_manager.get_local_plugins() if plugin.installed] + + +def _normalize_text(value: Optional[str]) -> str: + return (value or "").strip().lower() + + +def is_exact_plugin_match(plugin: Any, query: str) -> bool: + """ + 精确匹配插件 ID 或插件名称,用于安全地自动选择候选。 + """ + normalized_query = _normalize_text(query) + return normalized_query in { + _normalize_text(getattr(plugin, "id", None)), + _normalize_text(getattr(plugin, "plugin_name", None)), + } + + +def search_plugin_candidates(query: str, plugins: list[Any]) -> list[dict[str, Any]]: + """ + 按插件 ID、名称、描述和作者搜索候选,并返回打分结果。 + """ + normalized_query = _normalize_text(query) + if not normalized_query: + return [] + + tokens = [token for token in normalized_query.replace("-", " ").split() if token] + matches: list[dict[str, Any]] = [] + + for plugin in plugins: + plugin_id = _normalize_text(getattr(plugin, "id", None)) + plugin_name = _normalize_text(getattr(plugin, "plugin_name", None)) + plugin_desc = _normalize_text(getattr(plugin, "plugin_desc", None)) + plugin_author = _normalize_text(getattr(plugin, "plugin_author", None)) + haystack = "\n".join([plugin_id, plugin_name, plugin_desc, plugin_author]) + + score = 0 + if normalized_query == plugin_id: + score = 100 + elif normalized_query == plugin_name: + score = 95 + elif plugin_id.startswith(normalized_query): + score = 85 + elif plugin_name.startswith(normalized_query): + score = 80 + elif normalized_query in plugin_id: + score = 75 + elif normalized_query in plugin_name: + score = 70 + elif tokens and all(token in plugin_name for token in tokens): + score = 68 + elif tokens and all(token in plugin_id for token in tokens): + score = 66 + elif normalized_query in plugin_desc: + score = 45 + elif normalized_query in plugin_author: + score = 40 + elif tokens and all(token in haystack for token in tokens): + score = 35 + + if score <= 0: + continue + + matches.append( + { + "plugin": plugin, + "score": score, + "exact": is_exact_plugin_match(plugin, normalized_query), + } + ) + + return sorted( + matches, + key=lambda item: ( + -item["score"], + not item["exact"], + -int(bool(getattr(item["plugin"], "has_update", False))), + -int(bool(getattr(item["plugin"], "installed", False))), + -int(getattr(item["plugin"], "add_time", 0) or 0), + ), + ) + + +def summarize_candidates(matches: list[dict[str, Any]], limit: int = DEFAULT_PLUGIN_CANDIDATE_LIMIT) -> list[dict[str, Any]]: + """ + 压缩候选列表,避免一次性把完整市场数据返回给 Agent。 + """ + return [ + { + **summarize_plugin(item["plugin"]), + "score": item["score"], + "exact": item["exact"], + } + for item in matches[:limit] + ] + + +async def install_plugin_runtime( + plugin_id: str, repo_url: Optional[str], force: bool = False +) -> tuple[bool, str, bool]: + """ + 按现有插件接口的行为安装插件,并刷新运行态注册信息。 + """ + install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] + plugin_manager = PluginManager() + plugin_helper = PluginHelper() + + refreshed_only = False + if not force and plugin_id in plugin_manager.get_plugin_ids(): + refreshed_only = True + await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url) + message = "插件已存在,已刷新加载" + else: + if not repo_url: + return False, "没有传入仓库地址,无法正确安装插件,请检查配置", False + state, message = await plugin_helper.async_install( + pid=plugin_id, + repo_url=repo_url, + force_install=force, + ) + if not state: + return False, message, False + + if plugin_id not in install_plugins: + install_plugins.append(plugin_id) + await SystemConfigOper().async_set( + SystemConfigKey.UserInstalledPlugins, install_plugins + ) + + reload_plugin_runtime(plugin_id) + return True, message or "插件安装成功", refreshed_only + + +async def uninstall_plugin_runtime(plugin_id: str) -> dict[str, Any]: + """ + 按现有卸载逻辑移除插件,并清理运行态注册与分组信息。 + """ + from app.api.endpoints.plugin import _remove_plugin_from_folders, remove_plugin_api + from app.scheduler import Scheduler + + config_oper = SystemConfigOper() + install_plugins = config_oper.get(SystemConfigKey.UserInstalledPlugins) or [] + if plugin_id in install_plugins: + install_plugins = [plugin for plugin in install_plugins if plugin != plugin_id] + await config_oper.async_set(SystemConfigKey.UserInstalledPlugins, install_plugins) + + remove_plugin_api(plugin_id) + Scheduler().remove_plugin_job(plugin_id) + + plugin_manager = PluginManager() + plugin_class = plugin_manager.plugins.get(plugin_id) + was_clone = bool(getattr(plugin_class, "is_clone", False)) + clone_files_removed = False + + if was_clone: + plugin_manager.delete_plugin_config(plugin_id) + plugin_manager.delete_plugin_data(plugin_id) + plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower() + if plugin_base_dir.exists(): + try: + shutil.rmtree(plugin_base_dir) + plugin_manager.plugins.pop(plugin_id, None) + clone_files_removed = True + except Exception: + clone_files_removed = False + + _remove_plugin_from_folders(plugin_id) + plugin_manager.remove_plugin(plugin_id) + + return { + "was_clone": was_clone, + "clone_files_removed": clone_files_removed, + } diff --git a/app/agent/tools/impl/install_plugin.py b/app/agent/tools/impl/install_plugin.py new file mode 100644 index 00000000..a2609c4c --- /dev/null +++ b/app/agent/tools/impl/install_plugin.py @@ -0,0 +1,118 @@ +"""安装插件工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.agent.tools.impl._plugin_tool_utils import ( + get_plugin_snapshot, + install_plugin_runtime, + load_market_plugins, + summarize_plugin, +) +from app.log import logger + + +class InstallPluginInput(BaseModel): + """安装插件工具的输入参数模型""" + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + plugin_id: str = Field( + ..., + description="Exact plugin ID to install. Use query_market_plugins first to find the correct plugin_id.", + ) + force: bool = Field( + False, + description="Whether to force reinstall or upgrade the specified plugin.", + ) + force_refresh_market: bool = Field( + False, + description="Whether to refresh plugin market caches before reading the market list.", + ) + + +class InstallPluginTool(MoviePilotTool): + name: str = "install_plugin" + description: str = ( + "Install a plugin by exact plugin_id from the plugin market or local plugin repositories. " + "Use query_market_plugins first when you need filtering or discovery." + ) + require_admin: bool = True + args_schema: Type[BaseModel] = InstallPluginInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + plugin_id = kwargs.get("plugin_id") + return f"安装插件: {plugin_id or '未知插件'}" + + async def run( + self, + plugin_id: str, + force: bool = False, + force_refresh_market: bool = False, + **kwargs, + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, force={force}" + ) + + try: + plugins = await load_market_plugins(force_refresh=force_refresh_market) + if not plugins: + return json.dumps( + {"success": False, "message": "当前插件市场没有可用插件"}, + ensure_ascii=False, + ) + + candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None) + if not candidate: + return json.dumps( + { + "success": False, + "message": f"未在插件市场中找到插件: {plugin_id}。请先调用 query_market_plugins 确认 plugin_id。", + }, + ensure_ascii=False, + ) + + success, message, refreshed_only = await install_plugin_runtime( + candidate.id, + getattr(candidate, "repo_url", None), + force=force, + ) + if not success: + return json.dumps( + { + "success": False, + "plugin": summarize_plugin(candidate), + "message": message, + }, + ensure_ascii=False, + indent=2, + ) + + plugin_snapshot = get_plugin_snapshot(candidate.id) + if refreshed_only and getattr(candidate, "has_update", False) and not force: + message = "插件已安装,当前仅刷新加载;如需升级到市场新版本,请设置 force=true" + + return json.dumps( + { + "success": True, + "message": message, + "force": force, + "refreshed_only": refreshed_only, + "plugin": summarize_plugin(candidate), + "runtime": plugin_snapshot, + }, + ensure_ascii=False, + indent=2, + ) + except Exception as e: + logger.error(f"安装插件失败: {e}", exc_info=True) + return json.dumps( + {"success": False, "message": f"安装插件时发生错误: {str(e)}"}, + ensure_ascii=False, + ) diff --git a/app/agent/tools/impl/query_installed_plugins.py b/app/agent/tools/impl/query_installed_plugins.py index 218b2fd8..c090df0d 100644 --- a/app/agent/tools/impl/query_installed_plugins.py +++ b/app/agent/tools/impl/query_installed_plugins.py @@ -6,7 +6,13 @@ from typing import Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool -from app.core.plugin import PluginManager +from app.agent.tools.impl._plugin_tool_utils import ( + DEFAULT_PLUGIN_CANDIDATE_LIMIT, + list_installed_plugins, + search_plugin_candidates, + summarize_candidates, + summarize_plugin, +) from app.log import logger @@ -17,49 +23,86 @@ class QueryInstalledPluginsInput(BaseModel): ..., description="Clear explanation of why this tool is being used in the current context", ) + query: Optional[str] = Field( + None, + description="Optional keyword to filter installed plugins by plugin ID, name, description, or author.", + ) + max_results: Optional[int] = Field( + DEFAULT_PLUGIN_CANDIDATE_LIMIT, + description="Maximum number of plugins to return. Defaults to 10.", + ) class QueryInstalledPluginsTool(MoviePilotTool): name: str = "query_installed_plugins" description: str = ( - "Query all installed plugins in MoviePilot. Returns a list of installed plugins with their ID, name, " - "description, version, author, running state, and other information. " - "Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands." + "Query installed plugins in MoviePilot. Returns all installed plugins or filters them by keywords. " + "Use this tool to find the exact plugin_id before uninstall_plugin or other plugin management tools are used." ) require_admin: bool = True args_schema: Type[BaseModel] = QueryInstalledPluginsInput def get_tool_message(self, **kwargs) -> Optional[str]: """生成友好的提示消息""" + query = kwargs.get("query") + if query: + return f"查询已安装插件: {query}" return "查询已安装插件" @staticmethod - def _list_installed_plugins() -> list[dict]: - """读取已加载插件的内存快照。""" - plugin_manager = PluginManager() - local_plugins = plugin_manager.get_local_plugins() - installed_plugins = [plugin for plugin in local_plugins if plugin.installed] - return [ - { - "id": plugin.id, - "plugin_name": plugin.plugin_name, - "plugin_desc": plugin.plugin_desc, - "plugin_version": plugin.plugin_version, - "plugin_author": plugin.plugin_author, - "state": plugin.state, - "has_page": plugin.has_page, - } - for plugin in installed_plugins - ] + def _clamp_results(max_results: Optional[int]) -> int: + if max_results is None: + return DEFAULT_PLUGIN_CANDIDATE_LIMIT + return max(1, min(int(max_results), 200)) - async def run(self, **kwargs) -> str: - logger.info(f"执行工具: {self.name}") + async def run( + self, + query: Optional[str] = None, + max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT, + **kwargs, + ) -> str: + logger.info(f"执行工具: {self.name}, 参数: query={query}") try: - installed_plugins = self._list_installed_plugins() + installed_plugins = list_installed_plugins() if not installed_plugins: - return "当前没有已安装的插件" - result_json = json.dumps(installed_plugins, ensure_ascii=False, indent=2) - return result_json + return json.dumps( + {"success": False, "message": "当前没有已安装的插件"}, + ensure_ascii=False, + ) + + limit = self._clamp_results(max_results) + if query: + matches = search_plugin_candidates(query, installed_plugins) + return json.dumps( + { + "success": True, + "query": query, + "total_installed": len(installed_plugins), + "match_count": len(matches), + "truncated": len(matches) > limit, + "plugins": summarize_candidates(matches, limit=limit), + }, + ensure_ascii=False, + indent=2, + ) + + plugin_summaries = [ + summarize_plugin(plugin) for plugin in installed_plugins[:limit] + ] + return json.dumps( + { + "success": True, + "total_installed": len(installed_plugins), + "returned_count": len(plugin_summaries), + "truncated": len(installed_plugins) > limit, + "plugins": plugin_summaries, + }, + ensure_ascii=False, + indent=2, + ) except Exception as e: logger.error(f"查询已安装插件失败: {e}", exc_info=True) - return f"查询已安装插件时发生错误: {str(e)}" + return json.dumps( + {"success": False, "message": f"查询已安装插件时发生错误: {str(e)}"}, + ensure_ascii=False, + ) diff --git a/app/agent/tools/impl/query_market_plugins.py b/app/agent/tools/impl/query_market_plugins.py new file mode 100644 index 00000000..46875411 --- /dev/null +++ b/app/agent/tools/impl/query_market_plugins.py @@ -0,0 +1,113 @@ +"""查询插件市场工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.agent.tools.impl._plugin_tool_utils import ( + DEFAULT_PLUGIN_CANDIDATE_LIMIT, + load_market_plugins, + search_plugin_candidates, + summarize_candidates, + summarize_plugin, +) +from app.log import logger + + +class QueryMarketPluginsInput(BaseModel): + """查询插件市场工具的输入参数模型""" + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + query: Optional[str] = Field( + None, + description="Optional keyword to filter plugin market results by plugin ID, name, description, or author.", + ) + max_results: Optional[int] = Field( + DEFAULT_PLUGIN_CANDIDATE_LIMIT, + description="Maximum number of plugins to return. Defaults to 10.", + ) + force_refresh: Optional[bool] = Field( + False, + description="Whether to refresh plugin market caches before querying.", + ) + + +class QueryMarketPluginsTool(MoviePilotTool): + name: str = "query_market_plugins" + description: str = ( + "Query available plugins from the plugin market and local plugin repositories. " + "Can return the full plugin list or filter by keywords before install_plugin is used." + ) + require_admin: bool = True + args_schema: Type[BaseModel] = QueryMarketPluginsInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + query = kwargs.get("query") + if query: + return f"查询插件市场: {query}" + return "查询插件市场全部插件" + + @staticmethod + def _clamp_results(max_results: Optional[int]) -> int: + if max_results is None: + return DEFAULT_PLUGIN_CANDIDATE_LIMIT + return max(1, min(int(max_results), 200)) + + async def run( + self, + query: Optional[str] = None, + max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT, + force_refresh: bool = False, + **kwargs, + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: query={query}, force_refresh={force_refresh}" + ) + + try: + plugins = await load_market_plugins(force_refresh=force_refresh) + if not plugins: + return json.dumps( + {"success": False, "message": "当前插件市场没有可用插件"}, + ensure_ascii=False, + ) + + limit = self._clamp_results(max_results) + if query: + matches = search_plugin_candidates(query, plugins) + return json.dumps( + { + "success": True, + "query": query, + "total_available": len(plugins), + "match_count": len(matches), + "truncated": len(matches) > limit, + "plugins": summarize_candidates(matches, limit=limit), + }, + ensure_ascii=False, + indent=2, + ) + + plugin_summaries = [summarize_plugin(plugin) for plugin in plugins[:limit]] + return json.dumps( + { + "success": True, + "total_available": len(plugins), + "returned_count": len(plugin_summaries), + "truncated": len(plugins) > limit, + "plugins": plugin_summaries, + }, + ensure_ascii=False, + indent=2, + ) + except Exception as e: + logger.error(f"查询插件市场失败: {e}", exc_info=True) + return json.dumps( + {"success": False, "message": f"查询插件市场时发生错误: {str(e)}"}, + ensure_ascii=False, + ) diff --git a/app/agent/tools/impl/uninstall_plugin.py b/app/agent/tools/impl/uninstall_plugin.py new file mode 100644 index 00000000..1b64a0be --- /dev/null +++ b/app/agent/tools/impl/uninstall_plugin.py @@ -0,0 +1,84 @@ +"""卸载插件工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.agent.tools.impl._plugin_tool_utils import ( + list_installed_plugins, + summarize_plugin, + uninstall_plugin_runtime, +) +from app.log import logger + + +class UninstallPluginInput(BaseModel): + """卸载插件工具的输入参数模型""" + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + plugin_id: str = Field( + ..., + description="Exact plugin ID to uninstall. Use query_installed_plugins first to find the correct plugin_id.", + ) + + +class UninstallPluginTool(MoviePilotTool): + name: str = "uninstall_plugin" + description: str = ( + "Uninstall an installed plugin by exact plugin_id. " + "Use query_installed_plugins first when you need filtering or discovery." + ) + require_admin: bool = True + args_schema: Type[BaseModel] = UninstallPluginInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + plugin_id = kwargs.get("plugin_id") + return f"卸载插件: {plugin_id or '未知插件'}" + + async def run( + self, + plugin_id: str, + **kwargs, + ) -> str: + logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}") + + try: + plugins = list_installed_plugins() + if not plugins: + return json.dumps( + {"success": False, "message": "当前没有已安装的插件"}, + ensure_ascii=False, + ) + + candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None) + if not candidate: + return json.dumps( + { + "success": False, + "message": f"未找到已安装插件: {plugin_id}。请先调用 query_installed_plugins 确认 plugin_id。", + }, + ensure_ascii=False, + ) + + cleanup_result = await uninstall_plugin_runtime(candidate.id) + return json.dumps( + { + "success": True, + "message": f"插件 {candidate.id} 已卸载", + "plugin": summarize_plugin(candidate), + **cleanup_result, + }, + ensure_ascii=False, + indent=2, + ) + except Exception as e: + logger.error(f"卸载插件失败: {e}", exc_info=True) + return json.dumps( + {"success": False, "message": f"卸载插件时发生错误: {str(e)}"}, + ensure_ascii=False, + ) diff --git a/tests/test_agent_plugin_tools.py b/tests/test_agent_plugin_tools.py index 40a70d6d..1a3c35df 100644 --- a/tests/test_agent_plugin_tools.py +++ b/tests/test_agent_plugin_tools.py @@ -4,9 +4,13 @@ import unittest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch +from app.agent.tools.impl.install_plugin import InstallPluginTool +from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool +from app.agent.tools.impl.query_market_plugins import QueryMarketPluginsTool from app.agent.tools.impl.query_plugin_config import QueryPluginConfigTool from app.agent.tools.impl.query_plugin_data import QueryPluginDataTool from app.agent.tools.impl.reload_plugin import ReloadPluginTool +from app.agent.tools.impl.uninstall_plugin import UninstallPluginTool from app.agent.tools.impl.update_plugin_config import UpdatePluginConfigTool @@ -20,6 +24,57 @@ class TestAgentPluginTools(unittest.TestCase): "state": state, } + @staticmethod + def _market_plugin(plugin_id: str, plugin_name: str, installed: bool = False): + return SimpleNamespace( + id=plugin_id, + plugin_name=plugin_name, + plugin_desc=f"{plugin_name} description", + plugin_version="1.0.0", + plugin_author="author", + installed=installed, + has_update=False, + state=installed, + repo_url="https://example.com/market", + add_time=1, + ) + + def test_query_market_plugins_filters_candidates(self): + tool = QueryMarketPluginsTool(session_id="session-1", user_id="10001") + plugins = [ + self._market_plugin("DemoPlugin", "Demo Plugin"), + self._market_plugin("OtherPlugin", "Other Plugin"), + ] + + with patch( + "app.agent.tools.impl.query_market_plugins.load_market_plugins", + new=AsyncMock(return_value=plugins), + ): + result = asyncio.run(tool.run(query="demo")) + + payload = json.loads(result) + self.assertTrue(payload["success"]) + self.assertEqual(payload["match_count"], 1) + self.assertEqual(payload["plugins"][0]["id"], "DemoPlugin") + + def test_query_installed_plugins_filters_candidates(self): + tool = QueryInstalledPluginsTool(session_id="session-1", user_id="10001") + plugins = [ + self._market_plugin("DemoPlugin", "Demo Plugin", installed=True), + self._market_plugin("OtherPlugin", "Other Plugin", installed=True), + ] + + with patch( + "app.agent.tools.impl.query_installed_plugins.list_installed_plugins", + return_value=plugins, + ): + result = asyncio.run(tool.run(query="demo")) + + payload = json.loads(result) + self.assertTrue(payload["success"]) + self.assertEqual(payload["match_count"], 1) + self.assertEqual(payload["plugins"][0]["id"], "DemoPlugin") + def test_query_plugin_config_returns_saved_config_and_default_model(self): tool = QueryPluginConfigTool(session_id="session-1", user_id="10001") plugin_manager = MagicMock() @@ -92,6 +147,51 @@ class TestAgentPluginTools(unittest.TestCase): self.assertFalse(payload["state"]) reload_plugin_runtime.assert_called_once_with("DemoPlugin") + def test_install_plugin_installs_market_candidate(self): + tool = InstallPluginTool(session_id="session-1", user_id="10001") + candidate = self._market_plugin("DemoPlugin", "Demo Plugin") + + with patch( + "app.agent.tools.impl.install_plugin.load_market_plugins", + new=AsyncMock(return_value=[candidate]), + ), patch( + "app.agent.tools.impl.install_plugin.install_plugin_runtime", + new=AsyncMock(return_value=(True, "插件安装完成", False)), + ) as install_runtime, patch( + "app.agent.tools.impl.install_plugin.get_plugin_snapshot", + return_value=self._plugin_snapshot(), + ): + result = asyncio.run(tool.run(plugin_id="DemoPlugin")) + + payload = json.loads(result) + self.assertTrue(payload["success"]) + self.assertEqual(payload["plugin"]["id"], "DemoPlugin") + install_runtime.assert_awaited_once_with( + "DemoPlugin", "https://example.com/market", force=False + ) + + def test_uninstall_plugin_uninstalls_installed_candidate(self): + tool = UninstallPluginTool(session_id="session-1", user_id="10001") + installed_plugin = self._market_plugin( + "DemoPlugin", "Demo Plugin", installed=True + ) + + with patch( + "app.agent.tools.impl.uninstall_plugin.list_installed_plugins", + return_value=[installed_plugin], + ), patch( + "app.agent.tools.impl.uninstall_plugin.uninstall_plugin_runtime", + new=AsyncMock( + return_value={"was_clone": False, "clone_files_removed": False} + ), + ) as uninstall_runtime: + result = asyncio.run(tool.run(plugin_id="DemoPlugin")) + + payload = json.loads(result) + self.assertTrue(payload["success"]) + self.assertEqual(payload["plugin"]["id"], "DemoPlugin") + uninstall_runtime.assert_awaited_once_with("DemoPlugin") + def test_query_plugin_data_truncates_large_payload(self): tool = QueryPluginDataTool(session_id="session-1", user_id="10001") plugin_data_oper = MagicMock()