mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 05:52:43 +08:00
refactor: reorganize interaction chain
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.interaction import (
|
||||
from app.helper.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,9 +24,9 @@ from app.schemas.types import (
|
||||
ScrapingPolicy,
|
||||
SystemConfigKey,
|
||||
)
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
recognize_lock = Lock()
|
||||
@@ -44,10 +44,10 @@ class ScrapingOption:
|
||||
policy: ScrapingPolicy = ScrapingPolicy.MISSINGONLY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Union[str, ScrapingTarget],
|
||||
metadata: Union[str, ScrapingMetadata],
|
||||
value: Union[ScrapingPolicy, bool, str],
|
||||
self,
|
||||
type: Union[str, ScrapingTarget],
|
||||
metadata: Union[str, ScrapingMetadata],
|
||||
value: Union[ScrapingPolicy, bool, str],
|
||||
):
|
||||
if isinstance(type, ScrapingTarget):
|
||||
self.type = type
|
||||
@@ -105,7 +105,7 @@ class ScrapingConfig:
|
||||
self._policies[tuple(items)] = ScrapingOption(*items, value)
|
||||
|
||||
def option(
|
||||
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
|
||||
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
|
||||
) -> ScrapingOption:
|
||||
|
||||
if isinstance(item, ScrapingTarget):
|
||||
@@ -173,11 +173,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
def on_config_changed(self):
|
||||
self.scraping_policies = ScrapingConfig.from_system_config()
|
||||
|
||||
@staticmethod
|
||||
def _should_scrape(
|
||||
self,
|
||||
scraping_option: ScrapingOption,
|
||||
file_exists: bool,
|
||||
global_overwrite: bool = False,
|
||||
scraping_option: ScrapingOption,
|
||||
file_exists: bool,
|
||||
global_overwrite: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
判断是否应该执行刮削操作
|
||||
@@ -211,7 +211,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return False
|
||||
|
||||
def _save_file(
|
||||
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
|
||||
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
|
||||
):
|
||||
"""
|
||||
保存或上传文件
|
||||
@@ -224,7 +224,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return
|
||||
# 使用tempfile创建临时文件
|
||||
with NamedTemporaryFile(
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 写入内容
|
||||
@@ -248,7 +248,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.warn(f"文件保存失败:{path}")
|
||||
|
||||
def _download_and_save_image(
|
||||
self, fileitem: schemas.FileItem, path: Path, url: str
|
||||
self, fileitem: schemas.FileItem, path: Path, url: str
|
||||
):
|
||||
"""
|
||||
流式下载图片并保存到文件
|
||||
@@ -268,7 +268,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if r and r.status_code == 200:
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 流式写入文件
|
||||
@@ -295,12 +295,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(f"{url} 图片下载失败:{str(err)}!")
|
||||
|
||||
def _get_target_fileitem_and_path(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
item_type: ScrapingTarget,
|
||||
metadata_type: ScrapingMetadata,
|
||||
filename_hint: Optional[str] = None,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
item_type: ScrapingTarget,
|
||||
metadata_type: ScrapingMetadata,
|
||||
filename_hint: Optional[str] = None,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
) -> Tuple[schemas.FileItem, Optional[Path]]:
|
||||
"""
|
||||
根据当前上下文、刮削项类型和元数据类型生成目标 FileItem 和 Path
|
||||
@@ -318,8 +318,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 电影文件NFO: 放在电影文件同级目录,名称与电影文件主体一致,后缀.nfo
|
||||
final_filename = f"{target_dir_path.stem}.nfo"
|
||||
target_dir_item = (
|
||||
parent_fileitem
|
||||
or self.storagechain.get_parent_item(current_fileitem)
|
||||
parent_fileitem
|
||||
or self.storagechain.get_parent_item(current_fileitem)
|
||||
)
|
||||
if not target_dir_item:
|
||||
logger.error(
|
||||
@@ -354,8 +354,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 图片通常是放在当前目录 (current_fileitem) 下
|
||||
# 如果是 EPISODE 类型的图片(如thumb),通常也是放在文件同级目录,文件名与视频文件一致
|
||||
elif (
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
):
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{target_dir_path.stem}{hint_ext}"
|
||||
@@ -380,11 +380,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return target_dir_item, target_full_path
|
||||
|
||||
def metadata_nfo(
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
season: Optional[int] = None,
|
||||
episode: Optional[int] = None,
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
season: Optional[int] = None,
|
||||
episode: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取NFO文件内容文本
|
||||
@@ -402,8 +402,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
episode=episode,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def select_recognize_source(
|
||||
self, log_name: str, log_context: str, native_fn, plugin_fn
|
||||
log_name: str, log_context: str, native_fn, plugin_fn
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
选择识别模式,插件优先或原生优先
|
||||
@@ -436,7 +437,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
def recognize_by_meta(
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
根据主副标题识别媒体信息
|
||||
@@ -513,7 +514,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return self.recognize_media(meta=org_meta)
|
||||
|
||||
def recognize_by_path(
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
) -> Optional[Context]:
|
||||
"""
|
||||
根据文件路径识别媒体信息
|
||||
@@ -577,7 +578,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return meta, medias
|
||||
|
||||
def get_tmdbinfo_by_doubanid(
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据豆瓣ID获取TMDB信息
|
||||
@@ -648,7 +649,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
def get_doubaninfo_by_tmdbid(
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据TMDBID获取豆瓣信息
|
||||
@@ -752,8 +753,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 收集从根目录到文件的所有父目录
|
||||
current_path = sub_path.parent
|
||||
while (
|
||||
current_path != root_path
|
||||
and current_path.is_relative_to(root_path)
|
||||
current_path != root_path
|
||||
and current_path.is_relative_to(root_path)
|
||||
):
|
||||
all_dirs.add(current_path)
|
||||
current_path = current_path.parent
|
||||
@@ -805,15 +806,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _scrape_nfo_generic(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
NFO 刮削
|
||||
@@ -859,14 +860,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.warn(f"{nfo_path.name} NFO 文件生成失败!")
|
||||
|
||||
def _scrape_images_generic(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
图片刮削
|
||||
@@ -906,14 +907,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 判断是否匹配当前刮削的季号
|
||||
if item_type == ScrapingTarget.TV and image_name.lower().startswith(
|
||||
"season"
|
||||
"season"
|
||||
):
|
||||
logger.info(f"当前为电视剧根目录刮削,跳过季图片:{image_name}")
|
||||
continue
|
||||
if (
|
||||
item_type == ScrapingTarget.SEASON
|
||||
and season_number is not None
|
||||
and image_name.lower().startswith("season")
|
||||
item_type == ScrapingTarget.SEASON
|
||||
and season_number is not None
|
||||
and image_name.lower().startswith("season")
|
||||
):
|
||||
# 检查是否只下载当前刮削季的图片
|
||||
image_season_str = (
|
||||
@@ -921,7 +922,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
if image_season_str is not None and image_season_str != str(
|
||||
season_number
|
||||
season_number
|
||||
).rjust(2, "0"):
|
||||
logger.info(
|
||||
f"当前刮削季为:{season_number},跳过非本季图片:{image_name}"
|
||||
@@ -956,14 +957,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def scrape_metadata(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
init_folder: bool = True,
|
||||
parent: schemas.FileItem = None,
|
||||
overwrite: bool = False,
|
||||
recursive: bool = True,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
init_folder: bool = True,
|
||||
parent: schemas.FileItem = None,
|
||||
overwrite: bool = False,
|
||||
recursive: bool = True,
|
||||
):
|
||||
"""
|
||||
手动刮削媒体信息
|
||||
@@ -982,7 +983,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 当前文件路径
|
||||
filepath = Path(fileitem.path)
|
||||
if fileitem.type == "file" and (
|
||||
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
|
||||
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1022,14 +1023,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.info(f"{filepath.name} 刮削完成")
|
||||
|
||||
def _handle_movie_scraping(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电影刮削
|
||||
@@ -1051,20 +1052,18 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=init_folder,
|
||||
parent=parent,
|
||||
overwrite=overwrite,
|
||||
recursive=recursive,
|
||||
)
|
||||
|
||||
def _handle_movie_directory(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电影目录刮削
|
||||
@@ -1105,14 +1104,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_scraping(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧刮削
|
||||
@@ -1142,12 +1141,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_episode_file(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧集文件刮削
|
||||
@@ -1191,15 +1190,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_directory(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧目录刮削
|
||||
@@ -1209,9 +1208,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
files = self.storagechain.list_files(fileitem=fileitem) or []
|
||||
for file in files:
|
||||
if (
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and MetaInfo(file.name).begin_season is None
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and MetaInfo(file.name).begin_season is None
|
||||
):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
@@ -1235,13 +1234,13 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _initialize_tv_directory_metadata(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
):
|
||||
"""
|
||||
初始化电视剧目录元数据(识别季号并刮削)
|
||||
@@ -1296,8 +1295,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
logger.warn("无法识别元数据,跳过")
|
||||
|
||||
@staticmethod
|
||||
async def async_select_recognize_source(
|
||||
self, log_name: str, log_context: str, native_fn, plugin_fn
|
||||
log_name: str, log_context: str, native_fn, plugin_fn
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
选择识别模式,插件优先或原生优先(异步版本)
|
||||
@@ -1330,7 +1330,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
async def async_recognize_by_meta(
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
根据主副标题识别媒体信息(异步版本)
|
||||
@@ -1366,7 +1366,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
async def async_recognize_help(
|
||||
self, title: str, org_meta: MetaBase
|
||||
self, title: str, org_meta: MetaBase
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
请求辅助识别,返回媒体信息(异步版本)
|
||||
@@ -1417,7 +1417,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return await self.async_recognize_media(meta=org_meta)
|
||||
|
||||
async def async_recognize_by_path(
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
) -> Optional[Context]:
|
||||
"""
|
||||
根据文件路径识别媒体信息(异步版本)
|
||||
@@ -1455,7 +1455,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return Context(meta_info=file_meta, media_info=mediainfo)
|
||||
|
||||
async def async_search(
|
||||
self, title: str
|
||||
self, title: str
|
||||
) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
|
||||
"""
|
||||
搜索媒体/人物信息(异步版本)
|
||||
@@ -1502,7 +1502,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
@staticmethod
|
||||
def _extract_year_from_tmdb(
|
||||
tmdbinfo: dict, season: Optional[int] = None
|
||||
tmdbinfo: dict, season: Optional[int] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从TMDB信息中提取年份
|
||||
@@ -1522,11 +1522,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return year
|
||||
|
||||
def _match_tmdb_with_names(
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息
|
||||
@@ -1540,11 +1540,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def _async_match_tmdb_with_names(
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息(异步版本)
|
||||
@@ -1558,7 +1558,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def async_get_tmdbinfo_by_doubanid(
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据豆瓣ID获取TMDB信息(异步版本)
|
||||
@@ -1629,7 +1629,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def async_get_doubaninfo_by_tmdbid(
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据TMDBID获取豆瓣信息(异步版本)
|
||||
|
||||
1148
app/chain/message.py
1148
app/chain/message.py
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ from urllib.parse import urljoin
|
||||
from lxml import etree
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.helper.slash import (
|
||||
from app.helper.interaction import (
|
||||
SlashInteractionManager,
|
||||
build_navigation_buttons,
|
||||
format_markdown_table,
|
||||
@@ -1060,8 +1060,9 @@ class SiteChain(ChainBase):
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_site_list(
|
||||
self, site_list: List[Site], channel: Optional[MessageChannel]
|
||||
site_list: List[Site], channel: Optional[MessageChannel]
|
||||
) -> str:
|
||||
"""
|
||||
根据渠道能力格式化站点列表。
|
||||
|
||||
@@ -1,145 +1,18 @@
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.helper.slash import (
|
||||
from app.helper.interaction import (
|
||||
build_navigation_buttons,
|
||||
page_items,
|
||||
supports_interaction_buttons,
|
||||
update_or_post_message,
|
||||
update_or_post_message, skills_interaction_manager, PendingSkillsInteraction,
|
||||
)
|
||||
from app.helper.skill import SkillHelper, SkillInfo
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSkillsInteraction:
|
||||
"""
|
||||
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
view: str = "root"
|
||||
local_page: int = 0
|
||||
market_page: int = 0
|
||||
market_query: str = ""
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SkillsInteractionManager:
|
||||
"""
|
||||
管理用户当前的技能交互状态。
|
||||
|
||||
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSkillsInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
"""
|
||||
清理超时会话,避免按钮回调无限积累。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSkillsInteraction:
|
||||
"""
|
||||
为用户创建新会话,并替换掉旧的技能交互状态。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingSkillsInteraction(
|
||||
request_id=request_id,
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request_id] = request
|
||||
self._by_user[user_key] = request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按用户获取当前有效会话,供纯文本回复路由使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按请求 ID 获取会话,并校验会话归属用户。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束会话,释放用户和请求 ID 的双向索引。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有会话,主要用于测试场景。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
skills_interaction_manager = SkillsInteractionManager()
|
||||
|
||||
|
||||
class SkillsChain(ChainBase):
|
||||
"""
|
||||
处理 /skills 指令、按钮回调和文本式技能管理交互。
|
||||
@@ -153,11 +26,11 @@ class SkillsChain(ChainBase):
|
||||
self.skillhelper = SkillHelper()
|
||||
|
||||
def remote_manage(
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
source: Optional[str] = None,
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
/skills 入口。创建新会话并渲染首屏菜单。
|
||||
@@ -205,14 +78,14 @@ class SkillsChain(ChainBase):
|
||||
return request_id, action, index
|
||||
|
||||
def handle_callback_interaction(
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理按钮交互,并在同一条消息上刷新当前视图。
|
||||
@@ -364,12 +237,12 @@ class SkillsChain(ChainBase):
|
||||
return True
|
||||
|
||||
def handle_text_interaction(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
处理不支持按钮渠道上的文本指令,也兼容用户直接回复文字操作。
|
||||
@@ -660,42 +533,42 @@ class SkillsChain(ChainBase):
|
||||
return True
|
||||
|
||||
def _install_market_skill(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按当前市场页的可见序号安装技能,避免跨页序号歧义。
|
||||
"""
|
||||
market_skills = self._get_market_skills(request=request)
|
||||
page_items, page, _ = self._page_items(
|
||||
items, page, _ = self._page_items(
|
||||
items=market_skills,
|
||||
page=request.market_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
)
|
||||
request.market_page = page
|
||||
if page_index < 1 or page_index > len(page_items):
|
||||
if page_index < 1 or page_index > len(items):
|
||||
return False, "安装序号无效"
|
||||
return self.skillhelper.install_market_skill(page_items[page_index - 1])
|
||||
return self.skillhelper.install_market_skill(items[page_index - 1])
|
||||
|
||||
def _remove_local_skill(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按当前已安装页的可见序号删除技能,并拦截内置技能。
|
||||
"""
|
||||
local_skills = self.skillhelper.list_local_skills()
|
||||
page_items, page, _ = self._page_items(
|
||||
items, page, _ = self._page_items(
|
||||
items=local_skills,
|
||||
page=request.local_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
)
|
||||
request.local_page = page
|
||||
if page_index < 1 or page_index > len(page_items):
|
||||
if page_index < 1 or page_index > len(items):
|
||||
return False, "删除序号无效"
|
||||
target = page_items[page_index - 1]
|
||||
target = items[page_index - 1]
|
||||
if not target.removable:
|
||||
return False, f"技能 {target.id} 是内置技能,不能删除"
|
||||
return self.skillhelper.remove_local_skill(target.id)
|
||||
@@ -713,15 +586,15 @@ class SkillsChain(ChainBase):
|
||||
return self.skillhelper.remove_custom_market_source(target.source)
|
||||
|
||||
def _render_interaction(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
force_market_refresh: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
根据当前视图生成内容,并选择编辑原消息或发送新消息。
|
||||
@@ -758,9 +631,9 @@ class SkillsChain(ChainBase):
|
||||
)
|
||||
|
||||
def _build_root_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建根菜单视图,汇总本地技能和市场概览。
|
||||
@@ -809,14 +682,14 @@ class SkillsChain(ChainBase):
|
||||
return "技能管理", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_installed_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction
|
||||
self,
|
||||
request: PendingSkillsInteraction
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建已安装技能视图,列出来源和可删除状态。
|
||||
"""
|
||||
local_skills = self.skillhelper.list_local_skills()
|
||||
page_items, page, total_pages = self._page_items(
|
||||
items, page, total_pages = self._page_items(
|
||||
items=local_skills,
|
||||
page=request.local_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
@@ -824,11 +697,11 @@ class SkillsChain(ChainBase):
|
||||
request.local_page = page
|
||||
|
||||
text_lines = [f"第 {page + 1}/{total_pages} 页,共 {len(local_skills)} 个技能"]
|
||||
if not page_items:
|
||||
if not items:
|
||||
text_lines.append("")
|
||||
text_lines.append("当前没有已安装技能")
|
||||
else:
|
||||
for index, skill in enumerate(page_items, start=1):
|
||||
for index, skill in enumerate(items, start=1):
|
||||
action = "可删除" if skill.removable else "内置不可删"
|
||||
text_lines.extend(
|
||||
[
|
||||
@@ -869,9 +742,9 @@ class SkillsChain(ChainBase):
|
||||
return "已安装技能", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_market_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建技能市场视图,仅展示尚未安装的技能。
|
||||
@@ -880,7 +753,7 @@ class SkillsChain(ChainBase):
|
||||
request=request,
|
||||
force_market_refresh=force_market_refresh,
|
||||
)
|
||||
page_items, page, total_pages = self._page_items(
|
||||
items, page, total_pages = self._page_items(
|
||||
items=market_skills,
|
||||
page=request.market_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
@@ -897,14 +770,14 @@ class SkillsChain(ChainBase):
|
||||
"搜索输入中:直接回复关键词即可筛选市场技能,回复 `取消` 结束输入。",
|
||||
]
|
||||
)
|
||||
if not page_items:
|
||||
if not items:
|
||||
text_lines.append("")
|
||||
if request.market_query:
|
||||
text_lines.append("当前搜索没有匹配的市场技能")
|
||||
else:
|
||||
text_lines.append("当前没有可安装的市场技能")
|
||||
else:
|
||||
for index, skill in enumerate(page_items, start=1):
|
||||
for index, skill in enumerate(items, start=1):
|
||||
text_lines.extend(
|
||||
[
|
||||
"",
|
||||
@@ -970,8 +843,8 @@ class SkillsChain(ChainBase):
|
||||
return "技能市场", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_sources_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建技能源管理视图,提供自定义 GitHub 源的增删入口。
|
||||
@@ -1052,9 +925,9 @@ class SkillsChain(ChainBase):
|
||||
|
||||
@staticmethod
|
||||
def _page_items(
|
||||
items: List[SkillInfo],
|
||||
page: int,
|
||||
page_size: int,
|
||||
items: List[SkillInfo],
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> Tuple[List[SkillInfo], int, int]:
|
||||
"""
|
||||
返回当前页的数据,并把页码钳制到有效范围内。
|
||||
@@ -1080,9 +953,9 @@ class SkillsChain(ChainBase):
|
||||
|
||||
@staticmethod
|
||||
def _navigation_buttons(
|
||||
request: PendingSkillsInteraction,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
request: PendingSkillsInteraction,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> List[List[dict]]:
|
||||
"""
|
||||
为分页视图生成上一页和下一页按钮。
|
||||
@@ -1095,16 +968,16 @@ class SkillsChain(ChainBase):
|
||||
)
|
||||
|
||||
def _update_or_post_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
优先编辑原消息,编辑失败时再回退为发送新消息。
|
||||
@@ -1136,9 +1009,9 @@ class SkillsChain(ChainBase):
|
||||
return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出"
|
||||
|
||||
def _get_market_skills(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> List[SkillInfo]:
|
||||
"""
|
||||
获取当前 /skills 会话可见的市场技能,并应用搜索词过滤。
|
||||
@@ -1183,8 +1056,8 @@ class SkillsChain(ChainBase):
|
||||
|
||||
@staticmethod
|
||||
def _apply_market_search(
|
||||
request: PendingSkillsInteraction,
|
||||
query: str,
|
||||
request: PendingSkillsInteraction,
|
||||
query: str,
|
||||
) -> None:
|
||||
"""
|
||||
将会话切到市场搜索结果视图,并重置分页状态。
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -11,7 +12,7 @@ from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.helper.slash import (
|
||||
from app.helper.interaction import (
|
||||
SlashInteractionManager,
|
||||
build_navigation_buttons,
|
||||
format_markdown_table,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .cloudflare import under_challenge
|
||||
|
||||
626
app/helper/interaction.py
Normal file
626
app/helper/interaction.py
Normal file
@@ -0,0 +1,626 @@
|
||||
import math
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSlashInteraction:
|
||||
"""
|
||||
通用 slash 命令交互上下文。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
command: str
|
||||
page: int = 0
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SlashInteractionManager:
|
||||
"""
|
||||
管理单个 slash 命令的交互会话。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSlashInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
command: str,
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSlashInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request = PendingSlashInteraction(
|
||||
request_id=uuid.uuid4().hex[:12],
|
||||
user_id=user_key,
|
||||
command=command,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request.request_id] = request
|
||||
self._by_user[user_key] = request.request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
渠道同时支持按钮和回调时,优先使用按钮交互。
|
||||
"""
|
||||
return bool(
|
||||
channel
|
||||
and ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
)
|
||||
|
||||
|
||||
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
仅在支持 Markdown 的渠道上输出 Markdown 内容。
|
||||
"""
|
||||
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
|
||||
|
||||
|
||||
def page_items(
|
||||
items: Sequence[Any],
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> Tuple[List[Any], int, int]:
|
||||
"""
|
||||
对列表做分页并规范化页码。
|
||||
"""
|
||||
total = len(items)
|
||||
if total == 0:
|
||||
return [], 0, 1
|
||||
total_pages = max(1, math.ceil(total / max(1, page_size)))
|
||||
page = min(max(0, page), total_pages - 1)
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
return list(items[start:end]), page, total_pages
|
||||
|
||||
|
||||
def build_navigation_buttons(
|
||||
prefix: str,
|
||||
request: Any,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> List[List[dict]]:
|
||||
"""
|
||||
构造标准上一页/下一页按钮。
|
||||
"""
|
||||
buttons = []
|
||||
nav_row = []
|
||||
if page > 0:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "⬅️ 上一页",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-prev",
|
||||
}
|
||||
)
|
||||
if page < total_pages - 1:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "下一页 ➡️",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-next",
|
||||
}
|
||||
)
|
||||
if nav_row:
|
||||
buttons.append(nav_row)
|
||||
return buttons
|
||||
|
||||
|
||||
def update_or_post_message(
|
||||
chain,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
优先编辑原消息,失败时回退为发送新消息。
|
||||
"""
|
||||
if (
|
||||
original_message_id
|
||||
and original_chat_id
|
||||
and ChannelCapabilityManager.supports_editing(channel)
|
||||
):
|
||||
edited = chain.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
if edited:
|
||||
return
|
||||
|
||||
chain.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def escape_markdown_table_cell(value: object) -> str:
|
||||
"""
|
||||
最小化转义 Markdown 表格中的特殊字符。
|
||||
"""
|
||||
text = str(value or "").replace("\n", "<br>")
|
||||
return text.replace("|", "\\|")
|
||||
|
||||
|
||||
def format_markdown_table(
|
||||
headers: Sequence[str],
|
||||
rows: Sequence[Sequence[object]],
|
||||
) -> str:
|
||||
"""
|
||||
生成 Markdown 表格文本。
|
||||
"""
|
||||
header_line = (
|
||||
"| "
|
||||
+ " | ".join(escape_markdown_table_cell(item) for item in headers)
|
||||
+ " |"
|
||||
)
|
||||
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
|
||||
data_lines = [
|
||||
"| "
|
||||
+ " | ".join(escape_markdown_table_cell(item) for item in row)
|
||||
+ " |"
|
||||
for row in rows
|
||||
]
|
||||
return "\n".join([header_line, separator_line, *data_lines])
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingMediaInteraction:
|
||||
"""
|
||||
记录一次搜索/下载/订阅交互的当前上下文。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
keyword: str
|
||||
phase: str = "media"
|
||||
page: int = 0
|
||||
title: str = ""
|
||||
meta: Optional[MetaBase] = None
|
||||
current_media: Optional[MediaInfo] = None
|
||||
items: List[Any] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class MediaInteractionManager:
|
||||
"""
|
||||
管理用户当前激活的媒体交互状态。
|
||||
|
||||
每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingMediaInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
"""
|
||||
清理超时会话,避免内存中残留旧交互状态。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
action: str,
|
||||
keyword: str,
|
||||
title: str = "",
|
||||
meta: Optional[MetaBase] = None,
|
||||
items: Optional[List[Any]] = None,
|
||||
) -> PendingMediaInteraction:
|
||||
"""
|
||||
为用户创建新的交互状态,并替换旧会话。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
|
||||
request = PendingMediaInteraction(
|
||||
request_id=uuid.uuid4().hex[:12],
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
action=action,
|
||||
keyword=keyword,
|
||||
title=title,
|
||||
meta=meta,
|
||||
items=list(items or []),
|
||||
)
|
||||
self._by_id[request.request_id] = request
|
||||
self._by_user[user_key] = request.request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingMediaInteraction]:
|
||||
"""
|
||||
按用户读取当前会话,供文本回复和旧按钮兼容使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingMediaInteraction]:
|
||||
"""
|
||||
按请求 ID 读取会话,并校验用户归属。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束一条会话。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
清空所有交互状态,主要用于测试。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
media_interaction_manager = MediaInteractionManager()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""
|
||||
Agent 交互选项。
|
||||
"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""
|
||||
待处理的 Agent 客户端交互请求。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""
|
||||
管理 Agent 发起的客户端交互请求。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
"""
|
||||
创建一条待用户确认的 Agent 交互请求。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
"""
|
||||
消费一条 Agent 交互请求,并返回选中的选项。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
清空所有 Agent 交互请求。
|
||||
"""
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSkillsInteraction:
|
||||
"""
|
||||
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
view: str = "root"
|
||||
local_page: int = 0
|
||||
market_page: int = 0
|
||||
market_query: str = ""
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SkillsInteractionManager:
|
||||
"""
|
||||
管理用户当前的技能交互状态。
|
||||
|
||||
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSkillsInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
"""
|
||||
清理超时会话,避免按钮回调无限积累。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSkillsInteraction:
|
||||
"""
|
||||
为用户创建新会话,并替换掉旧的技能交互状态。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingSkillsInteraction(
|
||||
request_id=request_id,
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request_id] = request
|
||||
self._by_user[user_key] = request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按用户获取当前有效会话,供纯文本回复路由使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按请求 ID 获取会话,并校验会话归属用户。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束会话,释放用户和请求 ID 的双向索引。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有会话,主要用于测试场景。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
skills_interaction_manager = SkillsInteractionManager()
|
||||
@@ -1,244 +0,0 @@
|
||||
import math
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSlashInteraction:
|
||||
"""
|
||||
通用 slash 命令交互上下文。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
command: str
|
||||
page: int = 0
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SlashInteractionManager:
|
||||
"""
|
||||
管理单个 slash 命令的交互会话。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSlashInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
command: str,
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSlashInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request = PendingSlashInteraction(
|
||||
request_id=uuid.uuid4().hex[:12],
|
||||
user_id=user_key,
|
||||
command=command,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request.request_id] = request
|
||||
self._by_user[user_key] = request.request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
渠道同时支持按钮和回调时,优先使用按钮交互。
|
||||
"""
|
||||
return bool(
|
||||
channel
|
||||
and ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
)
|
||||
|
||||
|
||||
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
仅在支持 Markdown 的渠道上输出 Markdown 内容。
|
||||
"""
|
||||
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
|
||||
|
||||
|
||||
def page_items(
|
||||
items: Sequence,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> Tuple[List, int, int]:
|
||||
"""
|
||||
对列表做分页并规范化页码。
|
||||
"""
|
||||
total = len(items)
|
||||
if total == 0:
|
||||
return [], 0, 1
|
||||
total_pages = max(1, math.ceil(total / max(1, page_size)))
|
||||
page = min(max(0, page), total_pages - 1)
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
return list(items[start:end]), page, total_pages
|
||||
|
||||
|
||||
def build_navigation_buttons(
|
||||
prefix: str,
|
||||
request: PendingSlashInteraction,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> List[List[dict]]:
|
||||
"""
|
||||
构造标准上一页/下一页按钮。
|
||||
"""
|
||||
buttons = []
|
||||
nav_row = []
|
||||
if page > 0:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "⬅️ 上一页",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-prev",
|
||||
}
|
||||
)
|
||||
if page < total_pages - 1:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "下一页 ➡️",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-next",
|
||||
}
|
||||
)
|
||||
if nav_row:
|
||||
buttons.append(nav_row)
|
||||
return buttons
|
||||
|
||||
|
||||
def update_or_post_message(
|
||||
chain,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
优先编辑原消息,失败时回退为发送新消息。
|
||||
"""
|
||||
if (
|
||||
original_message_id
|
||||
and original_chat_id
|
||||
and ChannelCapabilityManager.supports_editing(channel)
|
||||
):
|
||||
edited = chain.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
if edited:
|
||||
return
|
||||
|
||||
chain.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def escape_markdown_table_cell(value: object) -> str:
|
||||
"""
|
||||
最小化转义 Markdown 表格中的特殊字符。
|
||||
"""
|
||||
text = str(value or "").replace("\n", "<br>")
|
||||
text = text.replace("|", "\\|")
|
||||
return text
|
||||
|
||||
|
||||
def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str:
|
||||
"""
|
||||
生成 Markdown 表格文本。
|
||||
"""
|
||||
header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |"
|
||||
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
|
||||
data_lines = [
|
||||
"| "
|
||||
+ " | ".join(escape_markdown_table_cell(item) for item in row)
|
||||
+ " |"
|
||||
for row in rows
|
||||
]
|
||||
return "\n".join([header_line, separator_line, *data_lines])
|
||||
@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.chain.interaction import (
|
||||
from app.helper.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
|
||||
setattr(sys.modules["transmission_rpc"], "File", object)
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
|
||||
from app.chain.interaction import MediaInteractionChain, media_interaction_manager
|
||||
from app.chain.media import MediaChain, media_interaction_manager
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
@@ -43,7 +43,7 @@ class TestMediaInteraction(unittest.TestCase):
|
||||
self.assertIsNotNone(request)
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
"app.chain.message.MediaChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
@@ -72,7 +72,7 @@ class TestMediaInteraction(unittest.TestCase):
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
|
||||
"app.chain.message.MediaChain.handle_callback_interaction",
|
||||
return_value=True,
|
||||
) as handle_callback:
|
||||
chain._handle_callback(
|
||||
@@ -86,7 +86,7 @@ class TestMediaInteraction(unittest.TestCase):
|
||||
handle_callback.assert_called_once()
|
||||
|
||||
def test_media_interaction_starts_search_and_posts_media_list(self):
|
||||
chain = MediaInteractionChain()
|
||||
chain = MediaChain()
|
||||
meta = self._build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
@@ -94,7 +94,7 @@ class TestMediaInteraction(unittest.TestCase):
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.chain.interaction.MediaChain.search",
|
||||
"app.chain.media.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
), patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_text_interaction(
|
||||
@@ -119,7 +119,7 @@ class TestMediaInteraction(unittest.TestCase):
|
||||
self.assertEqual(len(request.items), 2)
|
||||
|
||||
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
|
||||
chain = MediaInteractionChain()
|
||||
chain = MediaChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
|
||||
Reference in New Issue
Block a user