diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 9217c212..fa641083 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -2,11 +2,9 @@ import json from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Annotated, Any, Literal, Union, NotRequired +from typing import Annotated, Any, NotRequired from langchain.agents.middleware.types import ( - AgentMiddleware, AgentState, ContextT, ModelRequest, @@ -16,78 +14,18 @@ from langchain.agents.middleware.types import ( from langchain.agents.middleware.types import ( PrivateStateAttr, # noqa ) +from langchain.agents.middleware.tool_selection import ( + DEFAULT_SYSTEM_PROMPT, + LLMToolSelectorMiddleware, +) from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.runtime import Runtime -from pydantic import Field, TypeAdapter from typing_extensions import TypedDict # noqa from app.log import logger -DEFAULT_SYSTEM_PROMPT = ( - "Your goal is to select the most relevant tools for answering the user's query." -) - - -@dataclass -class _SelectionRequest: - """Prepared inputs for tool selection.""" - - available_tools: list[BaseTool] - system_message: str - last_user_message: HumanMessage - model: BaseChatModel - valid_tool_names: list[str] - - -def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]: - """Create a structured output schema for tool selection. - - Args: - tools: Available tools to include in the schema. - - Returns: - `TypeAdapter` for a schema where each tool name is a `Literal` with its - description. - - Raises: - AssertionError: If `tools` is empty. - """ - if not tools: - msg = "Invalid usage: tools must be non-empty" - raise AssertionError(msg) - - # Create a Union of Annotated Literal types for each tool name with description - # For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...] - literals = [ - Annotated[Literal[tool.name], Field(description=tool.description)] - for tool in tools # noqa - ] - selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007 - - description = "Tools to use. Place the most relevant tools first." - - class ToolSelectionResponse(TypedDict): - """Use to select relevant tools.""" - - tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type] - - return TypeAdapter(ToolSelectionResponse) - - -def _render_tool_list(tools: list[BaseTool]) -> str: - """Format tools as markdown list. - - Args: - tools: Tools to format. - - Returns: - Markdown string with each tool on a new line. - """ - return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools) - class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" @@ -102,9 +40,7 @@ class ToolSelectionStateUpdate(TypedDict): selected_tool_names: list[str] | None -class ToolSelectorMiddleware( - AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] -): +class ToolSelectorMiddleware(LLMToolSelectorMiddleware): """ 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 @@ -129,94 +65,19 @@ class ToolSelectorMiddleware( def __init__( self, - model: BaseChatModel, + model: BaseChatModel | str | None = None, system_prompt: str = DEFAULT_SYSTEM_PROMPT, selection_tools: list[Any] | None = None, max_tools: int | None = None, always_include: list[str] | None = None, ) -> None: - super().__init__() - self.model = model - self.system_prompt = system_prompt - self.max_tools = max_tools - self.always_include = always_include or [] - self.selection_tools = selection_tools or [] - - def _prepare_selection_request( - self, request: ModelRequest[ContextT] - ) -> _SelectionRequest | None: - """Prepare inputs for tool selection. - - Args: - request: the model request. - - Returns: - `SelectionRequest` with prepared inputs, or `None` if no selection is - needed. - - Raises: - ValueError: If tools in `always_include` are not found in the request. - AssertionError: If no user message is found in the request messages. - """ - # If no tools available, return None - if not request.tools or len(request.tools) == 0: - return None - - # Filter to only BaseTool instances (exclude provider-specific tool dicts) - base_tools = [tool for tool in request.tools if not isinstance(tool, dict)] - - # Validate that always_include tools exist - if self.always_include: - available_tool_names = {tool.name for tool in base_tools} - missing_tools = [ - name for name in self.always_include if name not in available_tool_names - ] - if missing_tools: - msg = ( - f"Tools in always_include not found in request: {missing_tools}. " - f"Available tools: {sorted(available_tool_names)}" - ) - raise ValueError(msg) - - # Separate tools that are always included from those available for selection - available_tools = [ - tool for tool in base_tools if tool.name not in self.always_include - ] - - # If no tools available for selection, return None - if not available_tools: - return None - - system_message = self.system_prompt - # If there's a max_tools limit, append instructions to the system prompt - if self.max_tools is not None: - system_message += ( - f"\nIMPORTANT: List the tool names in order of relevance, " - f"with the most relevant first. " - f"If you exceed the maximum number of tools, " - f"only the first {self.max_tools} will be used." - ) - - # Get the last user message from the conversation history - last_user_message: HumanMessage - for message in reversed(request.messages): - if isinstance(message, HumanMessage): - last_user_message = message - break - else: - msg = "No user message found in request messages" - raise AssertionError(msg) - - model = self.model or request.model - valid_tool_names = [tool.name for tool in available_tools] - - return _SelectionRequest( - available_tools=available_tools, - system_message=system_message, - last_user_message=last_user_message, + super().__init__( model=model, - valid_tool_names=valid_tool_names, + system_prompt=system_prompt, + max_tools=max_tools, + always_include=always_include, ) + self.selection_tools = selection_tools or [] def _process_selection_response( self, @@ -225,46 +86,29 @@ class ToolSelectorMiddleware( valid_tool_names: list[str], request: ModelRequest[ContextT], ) -> ModelRequest[ContextT]: - """Process the selection response and return filtered `ModelRequest`.""" - selected_tool_names: list[str] = [] - invalid_tool_selections = [] - - for tool_name in response["tools"]: - if tool_name not in valid_tool_names: - invalid_tool_selections.append(tool_name) - continue - - # Only add if not already selected and within max_tools limit - if tool_name not in selected_tool_names and ( - self.max_tools is None or len(selected_tool_names) < self.max_tools - ): - selected_tool_names.append(tool_name) - - if invalid_tool_selections: - msg = f"Model selected invalid tools: {invalid_tool_selections}" - raise ValueError(msg) - - # Filter tools based on selection and append always-included tools - if selected_tool_names: - selected_tools: list[BaseTool] = [ - tool for tool in available_tools if tool.name in selected_tool_names - ] - else: - # 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具 + """ + 处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。 + """ + if response.get("tools") == []: logger.warning("工具筛选结果为空,将恢复使用所有工具。") - selected_tools = available_tools - always_included_tools: list[BaseTool] = [ - tool - for tool in request.tools - if not isinstance(tool, dict) and tool.name in self.always_include - ] - selected_tools.extend(always_included_tools) + always_included_tools: list[BaseTool] = [ + tool + for tool in request.tools + if not isinstance(tool, dict) and tool.name in self.always_include + ] + provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] - # Also preserve any provider-specific tool dicts from the original request - provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] + return request.override( + tools=[*available_tools, *always_included_tools, *provider_tools] + ) - return request.override(tools=[*selected_tools, *provider_tools]) + return super()._process_selection_response( + response, + available_tools, + valid_tool_names, + request, + ) @staticmethod def _is_deepseek_compatible_model(model: BaseChatModel) -> bool: diff --git a/app/agent/tools/impl/query_transfer_history.py b/app/agent/tools/impl/query_transfer_history.py index 2163877c..8ec5c3a1 100644 --- a/app/agent/tools/impl/query_transfer_history.py +++ b/app/agent/tools/impl/query_transfer_history.py @@ -3,7 +3,6 @@ import json from typing import Optional, Type -import jieba from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool @@ -11,6 +10,7 @@ from app.db import AsyncSessionFactory from app.db.models.transferhistory import TransferHistory from app.log import logger from app.schemas.types import media_type_to_agent +from app.utils.jieba import cut as jieba_cut class QueryTransferHistoryInput(BaseModel): @@ -69,8 +69,8 @@ class QueryTransferHistoryTool(MoviePilotTool): async with AsyncSessionFactory() as db: # 处理标题搜索 if title: - # 使用 jieba 分词处理标题 - words = jieba.cut(title, HMM=False) + # 使用 fast-jieba 分词处理标题。 + words = jieba_cut(title, HMM=False) title_search = "%".join(words) # 查询记录 result = await TransferHistory.async_list_by_title( diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index 27a14138..328f96ee 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -3,7 +3,6 @@ import time from pathlib import Path from typing import List, Any, Optional -import jieba from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session @@ -24,6 +23,7 @@ from app.db.user_oper import ( ) from app.helper.progress import ProgressHelper from app.schemas.types import EventType +from app.utils.jieba import cut as jieba_cut router = APIRouter() @@ -272,7 +272,7 @@ async def transfer_history( db, title=like_pattern, page=page, count=count, status=status, wildcard=True ) else: - words = jieba.cut(title, HMM=False) + words = jieba_cut(title, HMM=False) like_pattern = "%".join(words) total = await TransferHistory.async_count_by_title( db, title=like_pattern, status=status diff --git a/app/core/meta/metaanime.py b/app/core/meta/metaanime.py index 240caf3b..cec554be 100644 --- a/app/core/meta/metaanime.py +++ b/app/core/meta/metaanime.py @@ -1,13 +1,13 @@ import re import traceback -import zhconv import anitopy from app.core.meta.customization import CustomizationMatcher from app.core.meta.metabase import MetaBase from app.core.meta.releasegroup import ReleaseGroupsMatcher from app.log import logger from app.utils.string import StringUtils +from app.utils.zhconv import convert as zhconv_convert from app.schemas.types import MediaType @@ -219,7 +219,7 @@ class MetaAnime(MetaBase): # 截掉分类 first_item = title.split(']')[0] if first_item and re.search(r"[动漫画纪录片电影视连续剧集日美韩中港台海外亚洲华语大陆综艺原盘高清]{2,}|TV|Animation|Movie|Documentar|Anime", - zhconv.convert(first_item, "zh-hans"), + zhconv_convert(first_item, "zh-hans"), re.IGNORECASE): title = re.sub(r"^[^]]*]", "", title).strip() # 去掉大小 diff --git a/app/modules/douban/__init__.py b/app/modules/douban/__init__.py index 986521fd..42d75d80 100644 --- a/app/modules/douban/__init__.py +++ b/app/modules/douban/__init__.py @@ -2,7 +2,6 @@ import re from typing import List, Optional, Tuple, Union import cn2an -import zhconv from app import schemas from app.core.config import settings @@ -19,6 +18,7 @@ from app.schemas.types import MediaType, ModuleType, MediaRecognizeType from app.utils.common import retry from app.utils.http import RequestUtils from app.utils.limit import rate_limit_exponential +from app.utils.zhconv import convert as zhconv_convert class DoubanModule(_ModuleBase): @@ -77,7 +77,7 @@ class DoubanModule(_ModuleBase): 准备搜索名称列表,保留中英文名称分别识别且按顺序去重的历史行为。 """ # 简体名称 - zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None + zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None # 使用中英文名分别识别,去重去空,但要保持顺序 return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k])) diff --git a/app/modules/themoviedb/__init__.py b/app/modules/themoviedb/__init__.py index 3f870b7c..9aab94b9 100644 --- a/app/modules/themoviedb/__init__.py +++ b/app/modules/themoviedb/__init__.py @@ -2,7 +2,6 @@ import re from typing import Optional, List, Tuple, Union, Dict import cn2an -import zhconv from app import schemas from app.core.config import settings @@ -17,6 +16,7 @@ from app.modules.themoviedb.tmdbapi import TmdbApi from app.schemas.category import CategoryConfig from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType from app.utils.http import RequestUtils +from app.utils.zhconv import convert as zhconv_convert _DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") @@ -116,7 +116,7 @@ class TheMovieDbModule(_ModuleBase): 准备搜索名称列表 """ # 简体名称 - zh_name = zhconv.convert(meta.cn_name, "zh-hans") if meta.cn_name else None + zh_name = zhconv_convert(meta.cn_name, "zh-hans") if meta.cn_name else None # 使用中英文名分别识别,去重去空,但要保持顺序 return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k])) diff --git a/app/modules/themoviedb/tmdbapi.py b/app/modules/themoviedb/tmdbapi.py index 34f3743f..5f78e5a8 100644 --- a/app/modules/themoviedb/tmdbapi.py +++ b/app/modules/themoviedb/tmdbapi.py @@ -2,12 +2,11 @@ import re import traceback from typing import Optional, List -import zhconv - from app.core.config import settings from app.log import logger from app.schemas.types import MediaType from app.utils.string import StringUtils +from app.utils.zhconv import convert as zhconv_convert from .tmdbv3api import TMDb, Search, Movie, TV, Season, Episode, Discover, Trending, Person, Collection from .tmdbv3api.exceptions import TMDbException @@ -726,7 +725,7 @@ class TmdbApi: if iso_3166_1 == "CN": title = alternative_title.get("title") if title and StringUtils.is_chinese(title) \ - and zhconv.convert(title, "zh-hans") == title: + and zhconv_convert(title, "zh-hans") == title: return title return tmdbinfo.get("title") if tmdbinfo.get("media_type") == MediaType.MOVIE else tmdbinfo.get("name") diff --git a/app/utils/jieba.py b/app/utils/jieba.py new file mode 100644 index 00000000..4a834042 --- /dev/null +++ b/app/utils/jieba.py @@ -0,0 +1,10 @@ +"""中文分词工具。""" + +from fast_jieba import cut as fast_jieba_cut + + +def cut(text: str, HMM: bool = True, cut_all: bool = False) -> list[str]: + """ + 使用 fast-jieba 执行中文分词,并兼容 jieba.cut 的常用参数名。 + """ + return fast_jieba_cut(text, hmm=HMM, cut_all=cut_all) diff --git a/app/utils/zhconv.py b/app/utils/zhconv.py new file mode 100644 index 00000000..55413870 --- /dev/null +++ b/app/utils/zhconv.py @@ -0,0 +1,10 @@ +"""中文简繁转换工具。""" + +from zhconv_rs import zhconv as _zhconv # pylint: disable=no-name-in-module + + +def convert(text: str, target: str) -> str: + """ + 使用 zhconv-rs 执行中文简繁转换,并隔离第三方包的函数名差异。 + """ + return _zhconv(text, target) diff --git a/requirements.in b/requirements.in index cac8c796..e0af5e36 100644 --- a/requirements.in +++ b/requirements.in @@ -14,10 +14,10 @@ alembic~=1.16.2 anyio~=4.10.0 bcrypt~=4.0.1 regex~=2024.11.6 -cn2an~=0.5.19 +cn2an~=0.5.24 dateparser~=1.2.2 python-dateutil~=2.8.2 -zhconv~=1.4.3 +zhconv-rs~=0.4.1 anitopy~=2.1.1 requests[socks]~=2.32.4 urllib3~=2.5.0 @@ -41,6 +41,7 @@ pyTelegramBotAPI~=4.27.0 telegramify-markdown~=0.5.2 cloakbrowser~=0.3.28 torrentool~=1.2.0 +fast-bencode~=1.1.7 slack-bolt~=1.23.0 slack-sdk~=3.35.0 discord.py==2.6.4 @@ -63,7 +64,7 @@ pywebpush~=2.0.3 aiosqlite~=0.21.0 psycopg2-binary~=2.9.10 asyncpg~=0.30.0 -jieba~=0.42.1 +fast-jieba~=0.4.0 rsa~=4.9 redis~=6.2.0 async_timeout~=5.0.1; python_full_version < "3.11.3" @@ -75,17 +76,17 @@ pympler~=1.1 smbprotocol~=1.15.0 setproctitle~=1.3.6 httpx[socks,http2]~=0.28.1 -langchain~=1.2.15 -langchain-core~=1.3.2 -langchain-community~=0.4.1 -langchain-anthropic~=1.4.2 -langchain-openai~=1.2.1 -langchain-google-genai~=4.2.2 +langchain~=1.3.1 +langchain-core~=1.4.0 +langchain-community~=0.4.2 +langchain-anthropic~=1.4.3 +langchain-openai~=1.2.2 +langchain-google-genai~=4.2.3 langchain-deepseek~=1.0.1 -langgraph~=1.1.9 -anthropic>=0.57,<1 -openai~=2.33.0 -google-genai~=1.74.0 +langgraph~=1.2.1 +anthropic~=0.104.1 +openai~=2.38.0 +google-genai~=1.75.0 ddgs~=9.10.0 websocket-client~=1.8.0 lark-oapi~=1.4.23 diff --git a/tests/test_fast_jieba_utils.py b/tests/test_fast_jieba_utils.py new file mode 100644 index 00000000..1568b223 --- /dev/null +++ b/tests/test_fast_jieba_utils.py @@ -0,0 +1,9 @@ +from app.utils.jieba import cut + + +def test_cut_accepts_legacy_hmm_argument(): + """验证兼容封装仍支持旧 jieba.cut 的 HMM 参数名。""" + words = cut("台湾后台测试", HMM=False) + + assert "".join(words) == "台湾后台测试" + assert "后台" in words diff --git a/tests/test_feishu.py b/tests/test_feishu.py index 51058830..caae6812 100644 --- a/tests/test_feishu.py +++ b/tests/test_feishu.py @@ -10,7 +10,6 @@ from unittest.mock import ANY, MagicMock, patch sys.modules.setdefault("psutil", ModuleType("psutil")) sys.modules.setdefault("cn2an", ModuleType("cn2an")) sys.modules.setdefault("dateparser", ModuleType("dateparser")) -sys.modules.setdefault("zhconv", ModuleType("zhconv")) if "Pinyin2Hanzi" not in sys.modules: pinyin_module = ModuleType("Pinyin2Hanzi")