mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-03 06:29:55 +08:00
feat: add agent token provider events
This commit is contained in:
@@ -43,10 +43,11 @@ from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import ChainEventType, EventType, MessageChannel
|
||||
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
|
||||
|
||||
|
||||
@@ -256,6 +257,9 @@ class MoviePilotAgent:
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
self._llm_runtime_config: Optional[Dict[str, Any]] = None
|
||||
self._llm_provider_selection: Dict[str, Any] = {}
|
||||
self._agent_started_at: Optional[datetime] = None
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -341,6 +345,40 @@ class MoviePilotAgent:
|
||||
)
|
||||
return self._session_usage.to_dict(self.session_id)
|
||||
|
||||
def _send_agent_tokens_usage_event(
|
||||
self,
|
||||
*,
|
||||
success: bool,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
广播本次 Agent 执行的 token 聚合用量,供配额类插件异步记录。
|
||||
"""
|
||||
try:
|
||||
selection = self._llm_provider_selection or {}
|
||||
event_data = AgentTokensUsageEventData(
|
||||
session_id=self.session_id,
|
||||
selected_provider_id=selection.get("selected_provider_id"),
|
||||
selected_provider_name=selection.get("selected_provider_name"),
|
||||
provider=selection.get("provider") or settings.LLM_PROVIDER,
|
||||
base_url=selection.get("base_url") or settings.LLM_BASE_URL,
|
||||
model=self._session_usage.model or selection.get("model") or settings.LLM_MODEL,
|
||||
input_tokens=self._session_usage.total_input_tokens,
|
||||
output_tokens=self._session_usage.total_output_tokens,
|
||||
total_tokens=self._session_usage.total_tokens,
|
||||
model_call_count=self._session_usage.model_call_count,
|
||||
success=success,
|
||||
error=error,
|
||||
started_at=self._agent_started_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self._agent_started_at
|
||||
else None,
|
||||
finished_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
source=selection.get("source") or "agent",
|
||||
)
|
||||
eventmanager.send_event(EventType.AgentTokensUsage, event_data)
|
||||
except Exception as err:
|
||||
logger.debug(f"广播 Agent Tokens 用量事件失败: {err}")
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""
|
||||
@@ -388,12 +426,113 @@ class MoviePilotAgent:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _initialize_llm(streaming: bool = False):
|
||||
def _get_event_value(event_data: Any, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从链式事件数据中兼容读取 Pydantic 模型或普通字典字段。
|
||||
"""
|
||||
if isinstance(event_data, dict):
|
||||
return event_data.get(key, default)
|
||||
return getattr(event_data, key, default)
|
||||
|
||||
@staticmethod
|
||||
def _set_event_value(event_data: Any, key: str, value: Any) -> None:
|
||||
"""
|
||||
向链式事件数据中兼容写入 Pydantic 模型或普通字典字段。
|
||||
"""
|
||||
if isinstance(event_data, dict):
|
||||
event_data[key] = value
|
||||
else:
|
||||
setattr(event_data, key, value)
|
||||
|
||||
@classmethod
|
||||
def _clean_optional_text(cls, value: Any) -> Optional[str]:
|
||||
"""
|
||||
标准化事件返回的可选文本字段,空字符串按未返回处理。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
|
||||
async def _resolve_llm_runtime_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
通过链式事件解析本次 Agent 可用的 LLM 运行时配置。
|
||||
|
||||
若没有插件返回 selected_provider_id,则沿用系统配置,保持既有行为。
|
||||
"""
|
||||
if self._llm_runtime_config is not None:
|
||||
return self._llm_runtime_config
|
||||
|
||||
event_data = AgentLLMProviderEventData(
|
||||
provider=settings.LLM_PROVIDER,
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=settings.LLM_API_KEY,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
base_url_preset=settings.LLM_BASE_URL_PRESET,
|
||||
thinking_level=None,
|
||||
)
|
||||
selected_event = await eventmanager.async_send_event(
|
||||
ChainEventType.AgentLLMProvider,
|
||||
event_data,
|
||||
)
|
||||
resolved_data = selected_event.event_data if selected_event else event_data
|
||||
|
||||
provider = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "provider"))
|
||||
or settings.LLM_PROVIDER
|
||||
)
|
||||
model = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "model"))
|
||||
or settings.LLM_MODEL
|
||||
)
|
||||
api_key = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "api_key"))
|
||||
or settings.LLM_API_KEY
|
||||
)
|
||||
base_url = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "base_url"))
|
||||
or settings.LLM_BASE_URL
|
||||
)
|
||||
base_url_preset = (
|
||||
self._clean_optional_text(self._get_event_value(resolved_data, "base_url_preset"))
|
||||
or settings.LLM_BASE_URL_PRESET
|
||||
)
|
||||
thinking_level = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "thinking_level")
|
||||
)
|
||||
selected_provider_id = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "selected_provider_id")
|
||||
)
|
||||
selected_provider_name = self._clean_optional_text(
|
||||
self._get_event_value(resolved_data, "selected_provider_name")
|
||||
)
|
||||
source = self._clean_optional_text(self._get_event_value(resolved_data, "source"))
|
||||
|
||||
self._llm_provider_selection = {
|
||||
"selected_provider_id": selected_provider_id,
|
||||
"selected_provider_name": selected_provider_name,
|
||||
"provider": provider,
|
||||
"base_url": base_url,
|
||||
"model": model,
|
||||
"source": source,
|
||||
}
|
||||
self._llm_runtime_config = {
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
"base_url_preset": base_url_preset,
|
||||
"thinking_level": thinking_level,
|
||||
}
|
||||
return self._llm_runtime_config
|
||||
|
||||
async def _initialize_llm(self, streaming: bool = False):
|
||||
"""
|
||||
初始化 LLM
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
return await LLMHelper.get_llm(streaming=streaming)
|
||||
runtime_config = await self._resolve_llm_runtime_config()
|
||||
return await LLMHelper.get_llm(streaming=streaming, **runtime_config)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
@@ -815,6 +954,11 @@ class MoviePilotAgent:
|
||||
- 渠道不支持消息编辑:非流式 LLM + ainvoke,完成后发送最终回复
|
||||
- 渠道支持消息编辑:流式 LLM + astream,实时推送 token
|
||||
"""
|
||||
execution_success = False
|
||||
execution_error: Optional[str] = None
|
||||
self._agent_started_at = datetime.now()
|
||||
self._llm_runtime_config = None
|
||||
self._llm_provider_selection = {}
|
||||
try:
|
||||
# Agent运行配置
|
||||
agent_config = {
|
||||
@@ -948,11 +1092,14 @@ class MoviePilotAgent:
|
||||
user_id=self.user_id,
|
||||
messages=agent.get_state(agent_config).values.get("messages", []),
|
||||
)
|
||||
execution_success = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
execution_error = "任务已取消"
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
execution_error = str(e)
|
||||
if self._messages_have_image_input(messages) and self._is_unsupported_image_input_error(e):
|
||||
logger.warning(
|
||||
f"当前模型不支持图片输入,已向用户发送友好提示: {e}"
|
||||
@@ -964,6 +1111,10 @@ class MoviePilotAgent:
|
||||
await self._dispatch_execution_notice(friendly_message)
|
||||
return friendly_message, {}
|
||||
finally:
|
||||
self._send_agent_tokens_usage_event(
|
||||
success=execution_success,
|
||||
error=execution_error,
|
||||
)
|
||||
# 确保停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
|
||||
|
||||
@@ -64,6 +64,53 @@ class ChainEventData(BaseEventData):
|
||||
pass
|
||||
|
||||
|
||||
class AgentLLMProviderEventData(ChainEventData):
|
||||
"""
|
||||
Agent LLM 供应商选择事件数据。
|
||||
|
||||
事件发出方会带入当前系统配置作为默认值;插件可覆盖 provider、base_url、
|
||||
api_key、model 等字段,并通过 selected_provider_id 标记本次选择,方便
|
||||
后续用量事件精确回写到同一个配额条目。
|
||||
"""
|
||||
|
||||
provider: Optional[str] = Field(default=None, description="LLM provider ID")
|
||||
base_url: Optional[str] = Field(default=None, description="API Base URL")
|
||||
api_key: Optional[str] = Field(default=None, description="API Key")
|
||||
model: Optional[str] = Field(default=None, description="模型名称")
|
||||
base_url_preset: Optional[str] = Field(default=None, description="Base URL 预设ID")
|
||||
thinking_level: Optional[str] = Field(default=None, description="思考模式级别")
|
||||
selected_provider_id: Optional[str] = Field(default=None, description="插件侧供应商ID")
|
||||
selected_provider_name: Optional[str] = Field(default=None, description="插件侧供应商名称")
|
||||
source: Optional[str] = Field(default=None, description="选择来源")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="扩展元数据")
|
||||
|
||||
|
||||
class AgentTokensUsageEventData(BaseEventData):
|
||||
"""
|
||||
Agent Tokens 用量广播事件数据。
|
||||
|
||||
用量事件不携带 API Key,只携带选择事件返回的 selected_provider_id 以及
|
||||
聚合后的 token 统计,避免把密钥扩散给广播订阅者。
|
||||
"""
|
||||
|
||||
session_id: str = Field(..., description="Agent 会话ID")
|
||||
selected_provider_id: Optional[str] = Field(default=None, description="插件侧供应商ID")
|
||||
selected_provider_name: Optional[str] = Field(default=None, description="插件侧供应商名称")
|
||||
provider: Optional[str] = Field(default=None, description="实际 LLM provider ID")
|
||||
base_url: Optional[str] = Field(default=None, description="API Base URL")
|
||||
model: Optional[str] = Field(default=None, description="模型名称")
|
||||
input_tokens: int = Field(default=0, description="输入 tokens")
|
||||
output_tokens: int = Field(default=0, description="输出 tokens")
|
||||
total_tokens: int = Field(default=0, description="总 tokens")
|
||||
model_call_count: int = Field(default=0, description="模型调用次数")
|
||||
success: bool = Field(default=False, description="Agent 执行是否成功")
|
||||
error: Optional[str] = Field(default=None, description="失败原因")
|
||||
started_at: Optional[str] = Field(default=None, description="开始时间")
|
||||
finished_at: Optional[str] = Field(default=None, description="结束时间")
|
||||
source: str = Field(default="agent", description="事件来源")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="扩展元数据")
|
||||
|
||||
|
||||
class AuthCredentials(ChainEventData):
|
||||
"""
|
||||
AuthVerification 事件的数据模型
|
||||
|
||||
@@ -105,6 +105,8 @@ class EventType(Enum):
|
||||
MessageAction = "message.action"
|
||||
# 执行工作流
|
||||
WorkflowExecute = "workflow.execute"
|
||||
# Agent Tokens 用量
|
||||
AgentTokensUsage = "agent.tokens.usage"
|
||||
|
||||
|
||||
# EventType中文名称翻译字典
|
||||
@@ -139,6 +141,7 @@ EVENT_TYPE_NAMES = {
|
||||
EventType.ConfigChanged: "配置项更新",
|
||||
EventType.MessageAction: "消息交互动作",
|
||||
EventType.WorkflowExecute: "执行工作流",
|
||||
EventType.AgentTokensUsage: "Agent Tokens 用量",
|
||||
}
|
||||
|
||||
|
||||
@@ -174,6 +177,8 @@ class ChainEventType(Enum):
|
||||
WorkflowExecution = "workflow.execution"
|
||||
# 存储操作选择
|
||||
StorageOperSelection = "storage.operation"
|
||||
# Agent LLM 供应商选择
|
||||
AgentLLMProvider = "agent.llm.provider"
|
||||
|
||||
|
||||
# 系统配置Key字典
|
||||
|
||||
192
tests/test_agent_tokens_events.py
Normal file
192
tests/test_agent_tokens_events.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.agent.memory import memory_manager
|
||||
from app.plugins.agenttokens import AgentTokens
|
||||
from app.schemas.types import ChainEventType, EventType
|
||||
|
||||
|
||||
class _FakeGraphState:
|
||||
"""提供 LangGraph get_state 测试替身。"""
|
||||
|
||||
def __init__(self, messages):
|
||||
self.values = {"messages": messages}
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
"""提供非流式 Agent 执行测试替身。"""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
async def ainvoke(self, _payload, config=None):
|
||||
"""模拟成功完成 Agent 调用。"""
|
||||
return None
|
||||
|
||||
def get_state(self, _config):
|
||||
"""返回测试消息状态。"""
|
||||
return _FakeGraphState(self._messages)
|
||||
|
||||
|
||||
class _FakeFailingAgent(_FakeAgent):
|
||||
"""提供失败 Agent 执行测试替身。"""
|
||||
|
||||
async def ainvoke(self, _payload, config=None):
|
||||
"""模拟 Agent 调用失败。"""
|
||||
raise RuntimeError("llm failed")
|
||||
|
||||
|
||||
class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_plugin_sidebar_nav_respects_config(self):
|
||||
"""插件侧边栏入口应受 show_sidebar_nav 配置控制。"""
|
||||
plugin = AgentTokens()
|
||||
|
||||
with patch.object(plugin, "update_config"):
|
||||
plugin.init_plugin(
|
||||
{
|
||||
"enabled": True,
|
||||
"show_sidebar_nav": False,
|
||||
"providers": [],
|
||||
}
|
||||
)
|
||||
self.assertEqual([], plugin.get_sidebar_nav())
|
||||
|
||||
plugin.init_plugin(
|
||||
{
|
||||
"enabled": True,
|
||||
"show_sidebar_nav": True,
|
||||
"providers": [],
|
||||
}
|
||||
)
|
||||
nav = plugin.get_sidebar_nav()
|
||||
|
||||
self.assertEqual("Agent Tokens 管理", nav[0]["title"])
|
||||
|
||||
async def test_initialize_llm_uses_chain_event_selection(self):
|
||||
"""Agent 初始化 LLM 时应优先使用链式事件返回的供应商配置。"""
|
||||
agent = MoviePilotAgent(session_id="agent-tokens-test", user_id="user-1")
|
||||
fake_llm = object()
|
||||
|
||||
async def select_provider(etype, data):
|
||||
"""模拟 Agent Tokens 插件写入供应商配置。"""
|
||||
self.assertEqual(ChainEventType.AgentLLMProvider, etype)
|
||||
data.provider = "openai"
|
||||
data.base_url = "https://tokens.example.com/v1"
|
||||
data.api_key = "sk-agent-token"
|
||||
data.model = "free-model"
|
||||
data.base_url_preset = None
|
||||
data.selected_provider_id = "provider-1"
|
||||
data.selected_provider_name = "Free Provider"
|
||||
data.source = "AgentTokens"
|
||||
return SimpleNamespace(event_data=data)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.agent.eventmanager.async_send_event",
|
||||
new=AsyncMock(side_effect=select_provider),
|
||||
) as send_event,
|
||||
patch("app.agent.LLMHelper.get_llm", new=AsyncMock(return_value=fake_llm)) as get_llm,
|
||||
):
|
||||
result = await agent._initialize_llm(streaming=True)
|
||||
second_result = await agent._initialize_llm(streaming=False)
|
||||
|
||||
self.assertIs(result, fake_llm)
|
||||
self.assertIs(second_result, fake_llm)
|
||||
send_event.assert_awaited_once()
|
||||
self.assertEqual(2, get_llm.await_count)
|
||||
get_llm.assert_any_await(
|
||||
streaming=True,
|
||||
provider="openai",
|
||||
model="free-model",
|
||||
api_key="sk-agent-token",
|
||||
base_url="https://tokens.example.com/v1",
|
||||
base_url_preset=None,
|
||||
thinking_level=None,
|
||||
)
|
||||
self.assertEqual("provider-1", agent._llm_provider_selection["selected_provider_id"])
|
||||
|
||||
async def test_execute_agent_broadcasts_usage_on_success(self):
|
||||
"""Agent 执行成功后应广播聚合 token 用量事件。"""
|
||||
agent = MoviePilotAgent(session_id="usage-success", user_id="user-1")
|
||||
agent._should_stream = lambda: False
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
async def create_agent(_streaming=False, streaming=False):
|
||||
"""模拟创建 Agent 时完成供应商选择和用量统计。"""
|
||||
agent._llm_provider_selection = {
|
||||
"selected_provider_id": "provider-1",
|
||||
"selected_provider_name": "Free Provider",
|
||||
"provider": "openai",
|
||||
"base_url": "https://tokens.example.com/v1",
|
||||
"model": "free-model",
|
||||
"source": "AgentTokens",
|
||||
}
|
||||
agent._record_usage(
|
||||
{
|
||||
"has_usage": True,
|
||||
"model": "free-model",
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 8,
|
||||
"total_tokens": 20,
|
||||
}
|
||||
)
|
||||
return _FakeAgent([AIMessage(content="ok")])
|
||||
|
||||
with (
|
||||
patch.object(agent, "_create_agent", new=create_agent),
|
||||
patch.object(memory_manager, "save_agent_messages"),
|
||||
patch("app.agent.eventmanager.send_event") as send_event,
|
||||
):
|
||||
await agent._execute_agent([])
|
||||
|
||||
send_event.assert_called_once()
|
||||
self.assertEqual(EventType.AgentTokensUsage, send_event.call_args.args[0])
|
||||
usage = send_event.call_args.args[1]
|
||||
self.assertTrue(usage.success)
|
||||
self.assertEqual("provider-1", usage.selected_provider_id)
|
||||
self.assertEqual(12, usage.input_tokens)
|
||||
self.assertEqual(8, usage.output_tokens)
|
||||
self.assertEqual(20, usage.total_tokens)
|
||||
|
||||
async def test_execute_agent_broadcasts_usage_on_failure(self):
|
||||
"""Agent 执行失败后仍应广播用量事件。"""
|
||||
agent = MoviePilotAgent(session_id="usage-failure", user_id="user-1")
|
||||
agent._should_stream = lambda: False
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
|
||||
async def create_agent(_streaming=False, streaming=False):
|
||||
"""模拟创建 Agent 时已选中供应商但执行失败。"""
|
||||
agent._llm_provider_selection = {
|
||||
"selected_provider_id": "provider-2",
|
||||
"selected_provider_name": "Backup Provider",
|
||||
"provider": "openai",
|
||||
"base_url": "https://backup.example.com/v1",
|
||||
"model": "backup-model",
|
||||
"source": "AgentTokens",
|
||||
}
|
||||
return _FakeFailingAgent([])
|
||||
|
||||
with (
|
||||
patch.object(agent, "_create_agent", new=create_agent),
|
||||
patch("app.agent.eventmanager.send_event") as send_event,
|
||||
):
|
||||
result, _ = await agent._execute_agent([])
|
||||
|
||||
self.assertIn("智能助手执行失败", result)
|
||||
send_event.assert_called_once()
|
||||
self.assertEqual(EventType.AgentTokensUsage, send_event.call_args.args[0])
|
||||
usage = send_event.call_args.args[1]
|
||||
self.assertFalse(usage.success)
|
||||
self.assertEqual("provider-2", usage.selected_provider_id)
|
||||
self.assertIn("llm failed", usage.error)
|
||||
Reference in New Issue
Block a user