refactor: reorganize interaction chain

This commit is contained in:
jxxghp
2026-05-01 09:53:04 +08:00
parent db6dc926cf
commit 4d0a722b09
12 changed files with 1975 additions and 1988 deletions

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.chain.interaction import (
from app.helper.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)

File diff suppressed because it is too large Load Diff

View File

@@ -24,9 +24,9 @@ from app.schemas.types import (
ScrapingPolicy,
SystemConfigKey,
)
from app.utils.http import RequestUtils
from app.utils.mixins import ConfigReloadMixin
from app.utils.singleton import Singleton
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
recognize_lock = Lock()
@@ -44,10 +44,10 @@ class ScrapingOption:
policy: ScrapingPolicy = ScrapingPolicy.MISSINGONLY
def __init__(
self,
type: Union[str, ScrapingTarget],
metadata: Union[str, ScrapingMetadata],
value: Union[ScrapingPolicy, bool, str],
self,
type: Union[str, ScrapingTarget],
metadata: Union[str, ScrapingMetadata],
value: Union[ScrapingPolicy, bool, str],
):
if isinstance(type, ScrapingTarget):
self.type = type
@@ -105,7 +105,7 @@ class ScrapingConfig:
self._policies[tuple(items)] = ScrapingOption(*items, value)
def option(
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
) -> ScrapingOption:
if isinstance(item, ScrapingTarget):
@@ -173,11 +173,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
def on_config_changed(self):
self.scraping_policies = ScrapingConfig.from_system_config()
@staticmethod
def _should_scrape(
self,
scraping_option: ScrapingOption,
file_exists: bool,
global_overwrite: bool = False,
scraping_option: ScrapingOption,
file_exists: bool,
global_overwrite: bool = False,
) -> bool:
"""
判断是否应该执行刮削操作
@@ -211,7 +211,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return False
def _save_file(
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
):
"""
保存或上传文件
@@ -224,7 +224,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return
# 使用tempfile创建临时文件
with NamedTemporaryFile(
delete=True, delete_on_close=False, suffix=path.suffix
delete=True, delete_on_close=False, suffix=path.suffix
) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 写入内容
@@ -248,7 +248,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.warn(f"文件保存失败:{path}")
def _download_and_save_image(
self, fileitem: schemas.FileItem, path: Path, url: str
self, fileitem: schemas.FileItem, path: Path, url: str
):
"""
流式下载图片并保存到文件
@@ -268,7 +268,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
if r and r.status_code == 200:
# 使用tempfile创建临时文件自动删除
with NamedTemporaryFile(
delete=True, delete_on_close=False, suffix=path.suffix
delete=True, delete_on_close=False, suffix=path.suffix
) as tmp_file:
tmp_file_path = Path(tmp_file.name)
# 流式写入文件
@@ -295,12 +295,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.error(f"{url} 图片下载失败:{str(err)}")
def _get_target_fileitem_and_path(
self,
current_fileitem: schemas.FileItem,
item_type: ScrapingTarget,
metadata_type: ScrapingMetadata,
filename_hint: Optional[str] = None,
parent_fileitem: Optional[schemas.FileItem] = None,
self,
current_fileitem: schemas.FileItem,
item_type: ScrapingTarget,
metadata_type: ScrapingMetadata,
filename_hint: Optional[str] = None,
parent_fileitem: Optional[schemas.FileItem] = None,
) -> Tuple[schemas.FileItem, Optional[Path]]:
"""
根据当前上下文、刮削项类型和元数据类型生成目标 FileItem 和 Path
@@ -318,8 +318,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 电影文件NFO: 放在电影文件同级目录,名称与电影文件主体一致,后缀.nfo
final_filename = f"{target_dir_path.stem}.nfo"
target_dir_item = (
parent_fileitem
or self.storagechain.get_parent_item(current_fileitem)
parent_fileitem
or self.storagechain.get_parent_item(current_fileitem)
)
if not target_dir_item:
logger.error(
@@ -354,8 +354,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 图片通常是放在当前目录 (current_fileitem) 下
# 如果是 EPISODE 类型的图片如thumb通常也是放在文件同级目录文件名与视频文件一致
elif (
metadata_type in [ScrapingMetadata.THUMB]
and item_type == ScrapingTarget.EPISODE
metadata_type in [ScrapingMetadata.THUMB]
and item_type == ScrapingTarget.EPISODE
):
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
final_filename = f"{target_dir_path.stem}{hint_ext}"
@@ -380,11 +380,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return target_dir_item, target_full_path
def metadata_nfo(
self,
meta: MetaBase,
mediainfo: MediaInfo,
season: Optional[int] = None,
episode: Optional[int] = None,
self,
meta: MetaBase,
mediainfo: MediaInfo,
season: Optional[int] = None,
episode: Optional[int] = None,
) -> Optional[str]:
"""
获取NFO文件内容文本
@@ -402,8 +402,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
episode=episode,
)
@staticmethod
def select_recognize_source(
self, log_name: str, log_context: str, native_fn, plugin_fn
log_name: str, log_context: str, native_fn, plugin_fn
) -> Optional[MediaInfo]:
"""
选择识别模式,插件优先或原生优先
@@ -436,7 +437,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
def recognize_by_meta(
self, metainfo: MetaBase, episode_group: Optional[str] = None
self, metainfo: MetaBase, episode_group: Optional[str] = None
) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息
@@ -513,7 +514,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return self.recognize_media(meta=org_meta)
def recognize_by_path(
self, path: str, episode_group: Optional[str] = None
self, path: str, episode_group: Optional[str] = None
) -> Optional[Context]:
"""
根据文件路径识别媒体信息
@@ -577,7 +578,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return meta, medias
def get_tmdbinfo_by_doubanid(
self, doubanid: str, mtype: MediaType = None
self, doubanid: str, mtype: MediaType = None
) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息
@@ -648,7 +649,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
def get_doubaninfo_by_tmdbid(
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息
@@ -752,8 +753,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 收集从根目录到文件的所有父目录
current_path = sub_path.parent
while (
current_path != root_path
and current_path.is_relative_to(root_path)
current_path != root_path
and current_path.is_relative_to(root_path)
):
all_dirs.add(current_path)
current_path = current_path.parent
@@ -805,15 +806,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _scrape_nfo_generic(
self,
current_fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
item_type: ScrapingTarget,
parent_fileitem: Optional[schemas.FileItem] = None,
overwrite: bool = False,
season_number: Optional[int] = None,
episode_number: Optional[int] = None,
self,
current_fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
item_type: ScrapingTarget,
parent_fileitem: Optional[schemas.FileItem] = None,
overwrite: bool = False,
season_number: Optional[int] = None,
episode_number: Optional[int] = None,
):
"""
NFO 刮削
@@ -859,14 +860,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.warn(f"{nfo_path.name} NFO 文件生成失败!")
def _scrape_images_generic(
self,
current_fileitem: schemas.FileItem,
mediainfo: MediaInfo,
item_type: ScrapingTarget,
parent_fileitem: Optional[schemas.FileItem] = None,
overwrite: bool = False,
season_number: Optional[int] = None,
episode_number: Optional[int] = None,
self,
current_fileitem: schemas.FileItem,
mediainfo: MediaInfo,
item_type: ScrapingTarget,
parent_fileitem: Optional[schemas.FileItem] = None,
overwrite: bool = False,
season_number: Optional[int] = None,
episode_number: Optional[int] = None,
):
"""
图片刮削
@@ -906,14 +907,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 判断是否匹配当前刮削的季号
if item_type == ScrapingTarget.TV and image_name.lower().startswith(
"season"
"season"
):
logger.info(f"当前为电视剧根目录刮削,跳过季图片:{image_name}")
continue
if (
item_type == ScrapingTarget.SEASON
and season_number is not None
and image_name.lower().startswith("season")
item_type == ScrapingTarget.SEASON
and season_number is not None
and image_name.lower().startswith("season")
):
# 检查是否只下载当前刮削季的图片
image_season_str = (
@@ -921,7 +922,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
if image_season_str is not None and image_season_str != str(
season_number
season_number
).rjust(2, "0"):
logger.info(
f"当前刮削季为:{season_number},跳过非本季图片:{image_name}"
@@ -956,14 +957,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def scrape_metadata(
self,
fileitem: schemas.FileItem,
meta: MetaBase = None,
mediainfo: MediaInfo = None,
init_folder: bool = True,
parent: schemas.FileItem = None,
overwrite: bool = False,
recursive: bool = True,
self,
fileitem: schemas.FileItem,
meta: MetaBase = None,
mediainfo: MediaInfo = None,
init_folder: bool = True,
parent: schemas.FileItem = None,
overwrite: bool = False,
recursive: bool = True,
):
"""
手动刮削媒体信息
@@ -982,7 +983,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 当前文件路径
filepath = Path(fileitem.path)
if fileitem.type == "file" and (
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
):
return
@@ -1022,14 +1023,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
logger.info(f"{filepath.name} 刮削完成")
def _handle_movie_scraping(
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
):
"""
处理电影刮削
@@ -1051,20 +1052,18 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
meta=meta,
mediainfo=mediainfo,
init_folder=init_folder,
parent=parent,
overwrite=overwrite,
recursive=recursive,
)
def _handle_movie_directory(
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
overwrite: bool,
recursive: bool,
):
"""
处理电影目录刮削
@@ -1105,14 +1104,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_scraping(
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
self,
fileitem: schemas.FileItem,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
):
"""
处理电视剧刮削
@@ -1142,12 +1141,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_episode_file(
self,
fileitem: schemas.FileItem,
filepath: Path,
mediainfo: MediaInfo,
parent: schemas.FileItem,
overwrite: bool,
self,
fileitem: schemas.FileItem,
filepath: Path,
mediainfo: MediaInfo,
parent: schemas.FileItem,
overwrite: bool,
):
"""
处理电视剧集文件刮削
@@ -1191,15 +1190,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _handle_tv_directory(
self,
fileitem: schemas.FileItem,
filepath: Path,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
self,
fileitem: schemas.FileItem,
filepath: Path,
meta: MetaBase,
mediainfo: MediaInfo,
init_folder: bool,
parent: schemas.FileItem,
overwrite: bool,
recursive: bool,
):
"""
处理电视剧目录刮削
@@ -1209,9 +1208,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
files = self.storagechain.list_files(fileitem=fileitem) or []
for file in files:
if (
file.type == "dir"
and file.name not in settings.RENAME_FORMAT_S0_NAMES
and MetaInfo(file.name).begin_season is None
file.type == "dir"
and file.name not in settings.RENAME_FORMAT_S0_NAMES
and MetaInfo(file.name).begin_season is None
):
# 电视剧不处理非季子目录
continue
@@ -1235,13 +1234,13 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
)
def _initialize_tv_directory_metadata(
self,
fileitem: schemas.FileItem,
filepath: Path,
meta: MetaBase,
mediainfo: MediaInfo,
parent: schemas.FileItem,
overwrite: bool,
self,
fileitem: schemas.FileItem,
filepath: Path,
meta: MetaBase,
mediainfo: MediaInfo,
parent: schemas.FileItem,
overwrite: bool,
):
"""
初始化电视剧目录元数据(识别季号并刮削)
@@ -1296,8 +1295,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
else:
logger.warn("无法识别元数据,跳过")
@staticmethod
async def async_select_recognize_source(
self, log_name: str, log_context: str, native_fn, plugin_fn
log_name: str, log_context: str, native_fn, plugin_fn
) -> Optional[MediaInfo]:
"""
选择识别模式,插件优先或原生优先(异步版本)
@@ -1330,7 +1330,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
async def async_recognize_by_meta(
self, metainfo: MetaBase, episode_group: Optional[str] = None
self, metainfo: MetaBase, episode_group: Optional[str] = None
) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息(异步版本)
@@ -1366,7 +1366,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return mediainfo
async def async_recognize_help(
self, title: str, org_meta: MetaBase
self, title: str, org_meta: MetaBase
) -> Optional[MediaInfo]:
"""
请求辅助识别,返回媒体信息(异步版本)
@@ -1417,7 +1417,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return await self.async_recognize_media(meta=org_meta)
async def async_recognize_by_path(
self, path: str, episode_group: Optional[str] = None
self, path: str, episode_group: Optional[str] = None
) -> Optional[Context]:
"""
根据文件路径识别媒体信息(异步版本)
@@ -1455,7 +1455,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return Context(meta_info=file_meta, media_info=mediainfo)
async def async_search(
self, title: str
self, title: str
) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
"""
搜索媒体/人物信息(异步版本)
@@ -1502,7 +1502,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
@staticmethod
def _extract_year_from_tmdb(
tmdbinfo: dict, season: Optional[int] = None
tmdbinfo: dict, season: Optional[int] = None
) -> Optional[str]:
"""
从TMDB信息中提取年份
@@ -1522,11 +1522,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return year
def _match_tmdb_with_names(
self,
meta_names: list,
year: Optional[str],
mtype: MediaType,
season: Optional[int] = None,
self,
meta_names: list,
year: Optional[str],
mtype: MediaType,
season: Optional[int] = None,
) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息
@@ -1540,11 +1540,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def _async_match_tmdb_with_names(
self,
meta_names: list,
year: Optional[str],
mtype: MediaType,
season: Optional[int] = None,
self,
meta_names: list,
year: Optional[str],
mtype: MediaType,
season: Optional[int] = None,
) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息异步版本
@@ -1558,7 +1558,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def async_get_tmdbinfo_by_doubanid(
self, doubanid: str, mtype: MediaType = None
self, doubanid: str, mtype: MediaType = None
) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息异步版本
@@ -1629,7 +1629,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return None
async def async_get_doubaninfo_by_tmdbid(
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息异步版本

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ from urllib.parse import urljoin
from lxml import etree
from app.chain import ChainBase
from app.helper.slash import (
from app.helper.interaction import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,
@@ -1060,8 +1060,9 @@ class SiteChain(ChainBase):
original_chat_id=original_chat_id,
)
@staticmethod
def _format_site_list(
self, site_list: List[Site], channel: Optional[MessageChannel]
site_list: List[Site], channel: Optional[MessageChannel]
) -> str:
"""
根据渠道能力格式化站点列表。

View File

@@ -1,145 +1,18 @@
import re
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List, Optional, Tuple, Union
import uuid
from typing import List, Optional, Tuple, Union
from app.chain import ChainBase
from app.helper.slash import (
from app.helper.interaction import (
build_navigation_buttons,
page_items,
supports_interaction_buttons,
update_or_post_message,
update_or_post_message, skills_interaction_manager, PendingSkillsInteraction,
)
from app.helper.skill import SkillHelper, SkillInfo
from app.schemas import Notification
from app.schemas.types import MessageChannel
@dataclass
class PendingSkillsInteraction:
"""
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
view: str = "root"
local_page: int = 0
market_page: int = 0
market_query: str = ""
awaiting_input: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
class SkillsInteractionManager:
"""
管理用户当前的技能交互状态。
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingSkillsInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self):
"""
清理超时会话,避免按钮回调无限积累。
"""
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
) -> PendingSkillsInteraction:
"""
为用户创建新会话,并替换掉旧的技能交互状态。
"""
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request_id = uuid.uuid4().hex[:12]
request = PendingSkillsInteraction(
request_id=request_id,
user_id=user_key,
channel=channel,
source=source,
username=username,
)
self._by_id[request_id] = request
self._by_user[user_key] = request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingSkillsInteraction]:
"""
按用户获取当前有效会话,供纯文本回复路由使用。
"""
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingSkillsInteraction]:
"""
按请求 ID 获取会话,并校验会话归属用户。
"""
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
"""
主动结束会话,释放用户和请求 ID 的双向索引。
"""
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self):
"""
清空所有会话,主要用于测试场景。
"""
with self._lock:
self._by_id.clear()
self._by_user.clear()
skills_interaction_manager = SkillsInteractionManager()
class SkillsChain(ChainBase):
"""
处理 /skills 指令、按钮回调和文本式技能管理交互。
@@ -153,11 +26,11 @@ class SkillsChain(ChainBase):
self.skillhelper = SkillHelper()
def remote_manage(
self,
arg_str: str,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
self,
arg_str: str,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
):
"""
/skills 入口。创建新会话并渲染首屏菜单。
@@ -205,14 +78,14 @@ class SkillsChain(ChainBase):
return request_id, action, index
def handle_callback_interaction(
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
处理按钮交互,并在同一条消息上刷新当前视图。
@@ -364,12 +237,12 @@ class SkillsChain(ChainBase):
return True
def handle_text_interaction(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
) -> bool:
"""
处理不支持按钮渠道上的文本指令,也兼容用户直接回复文字操作。
@@ -660,42 +533,42 @@ class SkillsChain(ChainBase):
return True
def _install_market_skill(
self,
request: PendingSkillsInteraction,
page_index: int,
self,
request: PendingSkillsInteraction,
page_index: int,
) -> Tuple[bool, str]:
"""
按当前市场页的可见序号安装技能,避免跨页序号歧义。
"""
market_skills = self._get_market_skills(request=request)
page_items, page, _ = self._page_items(
items, page, _ = self._page_items(
items=market_skills,
page=request.market_page,
page_size=self._page_size(request.channel),
)
request.market_page = page
if page_index < 1 or page_index > len(page_items):
if page_index < 1 or page_index > len(items):
return False, "安装序号无效"
return self.skillhelper.install_market_skill(page_items[page_index - 1])
return self.skillhelper.install_market_skill(items[page_index - 1])
def _remove_local_skill(
self,
request: PendingSkillsInteraction,
page_index: int,
self,
request: PendingSkillsInteraction,
page_index: int,
) -> Tuple[bool, str]:
"""
按当前已安装页的可见序号删除技能,并拦截内置技能。
"""
local_skills = self.skillhelper.list_local_skills()
page_items, page, _ = self._page_items(
items, page, _ = self._page_items(
items=local_skills,
page=request.local_page,
page_size=self._page_size(request.channel),
)
request.local_page = page
if page_index < 1 or page_index > len(page_items):
if page_index < 1 or page_index > len(items):
return False, "删除序号无效"
target = page_items[page_index - 1]
target = items[page_index - 1]
if not target.removable:
return False, f"技能 {target.id} 是内置技能,不能删除"
return self.skillhelper.remove_local_skill(target.id)
@@ -713,15 +586,15 @@ class SkillsChain(ChainBase):
return self.skillhelper.remove_custom_market_source(target.source)
def _render_interaction(
self,
request: PendingSkillsInteraction,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
force_market_refresh: bool = False,
self,
request: PendingSkillsInteraction,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
force_market_refresh: bool = False,
) -> None:
"""
根据当前视图生成内容,并选择编辑原消息或发送新消息。
@@ -758,9 +631,9 @@ class SkillsChain(ChainBase):
)
def _build_root_view(
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建根菜单视图,汇总本地技能和市场概览。
@@ -809,14 +682,14 @@ class SkillsChain(ChainBase):
return "技能管理", "\n".join(text_lines), buttons
def _build_installed_view(
self,
request: PendingSkillsInteraction
self,
request: PendingSkillsInteraction
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建已安装技能视图,列出来源和可删除状态。
"""
local_skills = self.skillhelper.list_local_skills()
page_items, page, total_pages = self._page_items(
items, page, total_pages = self._page_items(
items=local_skills,
page=request.local_page,
page_size=self._page_size(request.channel),
@@ -824,11 +697,11 @@ class SkillsChain(ChainBase):
request.local_page = page
text_lines = [f"{page + 1}/{total_pages} 页,共 {len(local_skills)} 个技能"]
if not page_items:
if not items:
text_lines.append("")
text_lines.append("当前没有已安装技能")
else:
for index, skill in enumerate(page_items, start=1):
for index, skill in enumerate(items, start=1):
action = "可删除" if skill.removable else "内置不可删"
text_lines.extend(
[
@@ -869,9 +742,9 @@ class SkillsChain(ChainBase):
return "已安装技能", "\n".join(text_lines), buttons
def _build_market_view(
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建技能市场视图,仅展示尚未安装的技能。
@@ -880,7 +753,7 @@ class SkillsChain(ChainBase):
request=request,
force_market_refresh=force_market_refresh,
)
page_items, page, total_pages = self._page_items(
items, page, total_pages = self._page_items(
items=market_skills,
page=request.market_page,
page_size=self._page_size(request.channel),
@@ -897,14 +770,14 @@ class SkillsChain(ChainBase):
"搜索输入中:直接回复关键词即可筛选市场技能,回复 `取消` 结束输入。",
]
)
if not page_items:
if not items:
text_lines.append("")
if request.market_query:
text_lines.append("当前搜索没有匹配的市场技能")
else:
text_lines.append("当前没有可安装的市场技能")
else:
for index, skill in enumerate(page_items, start=1):
for index, skill in enumerate(items, start=1):
text_lines.extend(
[
"",
@@ -970,8 +843,8 @@ class SkillsChain(ChainBase):
return "技能市场", "\n".join(text_lines), buttons
def _build_sources_view(
self,
request: PendingSkillsInteraction,
self,
request: PendingSkillsInteraction,
) -> Tuple[str, str, Optional[List[List[dict]]]]:
"""
构建技能源管理视图,提供自定义 GitHub 源的增删入口。
@@ -1052,9 +925,9 @@ class SkillsChain(ChainBase):
@staticmethod
def _page_items(
items: List[SkillInfo],
page: int,
page_size: int,
items: List[SkillInfo],
page: int,
page_size: int,
) -> Tuple[List[SkillInfo], int, int]:
"""
返回当前页的数据,并把页码钳制到有效范围内。
@@ -1080,9 +953,9 @@ class SkillsChain(ChainBase):
@staticmethod
def _navigation_buttons(
request: PendingSkillsInteraction,
page: int,
total_pages: int,
request: PendingSkillsInteraction,
page: int,
total_pages: int,
) -> List[List[dict]]:
"""
为分页视图生成上一页和下一页按钮。
@@ -1095,16 +968,16 @@ class SkillsChain(ChainBase):
)
def _update_or_post_message(
self,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
title: str,
text: str,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
self,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
title: str,
text: str,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
优先编辑原消息,编辑失败时再回退为发送新消息。
@@ -1136,9 +1009,9 @@ class SkillsChain(ChainBase):
return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出"
def _get_market_skills(
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
self,
request: PendingSkillsInteraction,
force_market_refresh: bool = False,
) -> List[SkillInfo]:
"""
获取当前 /skills 会话可见的市场技能,并应用搜索词过滤。
@@ -1183,8 +1056,8 @@ class SkillsChain(ChainBase):
@staticmethod
def _apply_market_search(
request: PendingSkillsInteraction,
query: str,
request: PendingSkillsInteraction,
query: str,
) -> None:
"""
将会话切到市场搜索结果视图,并重置分页状态。

View File

@@ -1,6 +1,7 @@
import copy
import json
import random
import re
import threading
import time
from datetime import datetime
@@ -11,7 +12,7 @@ from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.helper.slash import (
from app.helper.interaction import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,

View File

@@ -1 +0,0 @@
from .cloudflare import under_challenge

626
app/helper/interaction.py Normal file
View File

@@ -0,0 +1,626 @@
import math
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from app.core.context import MediaInfo
from app.core.meta import MetaBase
from app.schemas import Notification
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
@dataclass
class PendingSlashInteraction:
"""
通用 slash 命令交互上下文。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
command: str
page: int = 0
awaiting_input: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
class SlashInteractionManager:
"""
管理单个 slash 命令的交互会话。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingSlashInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self) -> None:
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
command: str,
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
) -> PendingSlashInteraction:
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request = PendingSlashInteraction(
request_id=uuid.uuid4().hex[:12],
user_id=user_key,
command=command,
channel=channel,
source=source,
username=username,
)
self._by_id[request.request_id] = request
self._by_user[user_key] = request.request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self) -> None:
with self._lock:
self._by_id.clear()
self._by_user.clear()
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
"""
渠道同时支持按钮和回调时,优先使用按钮交互。
"""
return bool(
channel
and ChannelCapabilityManager.supports_buttons(channel)
and ChannelCapabilityManager.supports_callbacks(channel)
)
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
"""
仅在支持 Markdown 的渠道上输出 Markdown 内容。
"""
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
def page_items(
items: Sequence[Any],
page: int,
page_size: int,
) -> Tuple[List[Any], int, int]:
"""
对列表做分页并规范化页码。
"""
total = len(items)
if total == 0:
return [], 0, 1
total_pages = max(1, math.ceil(total / max(1, page_size)))
page = min(max(0, page), total_pages - 1)
start = page * page_size
end = start + page_size
return list(items[start:end]), page, total_pages
def build_navigation_buttons(
prefix: str,
request: Any,
page: int,
total_pages: int,
) -> List[List[dict]]:
"""
构造标准上一页/下一页按钮。
"""
buttons = []
nav_row = []
if page > 0:
nav_row.append(
{
"text": "⬅️ 上一页",
"callback_data": f"{prefix}:{request.request_id}:page-prev",
}
)
if page < total_pages - 1:
nav_row.append(
{
"text": "下一页 ➡️",
"callback_data": f"{prefix}:{request.request_id}:page-next",
}
)
if nav_row:
buttons.append(nav_row)
return buttons
def update_or_post_message(
chain,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
title: str,
text: str,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
优先编辑原消息,失败时回退为发送新消息。
"""
if (
original_message_id
and original_chat_id
and ChannelCapabilityManager.supports_editing(channel)
):
edited = chain.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=text,
buttons=buttons,
)
if edited:
return
chain.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=title,
text=text,
buttons=buttons,
)
)
def escape_markdown_table_cell(value: object) -> str:
"""
最小化转义 Markdown 表格中的特殊字符。
"""
text = str(value or "").replace("\n", "<br>")
return text.replace("|", "\\|")
def format_markdown_table(
headers: Sequence[str],
rows: Sequence[Sequence[object]],
) -> str:
"""
生成 Markdown 表格文本。
"""
header_line = (
"| "
+ " | ".join(escape_markdown_table_cell(item) for item in headers)
+ " |"
)
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
data_lines = [
"| "
+ " | ".join(escape_markdown_table_cell(item) for item in row)
+ " |"
for row in rows
]
return "\n".join([header_line, separator_line, *data_lines])
@dataclass
class PendingMediaInteraction:
"""
记录一次搜索/下载/订阅交互的当前上下文。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
action: str
keyword: str
phase: str = "media"
page: int = 0
title: str = ""
meta: Optional[MetaBase] = None
current_media: Optional[MediaInfo] = None
items: List[Any] = field(default_factory=list)
created_at: datetime = field(default_factory=datetime.now)
class MediaInteractionManager:
"""
管理用户当前激活的媒体交互状态。
每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingMediaInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self) -> None:
"""
清理超时会话,避免内存中残留旧交互状态。
"""
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
action: str,
keyword: str,
title: str = "",
meta: Optional[MetaBase] = None,
items: Optional[List[Any]] = None,
) -> PendingMediaInteraction:
"""
为用户创建新的交互状态,并替换旧会话。
"""
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request = PendingMediaInteraction(
request_id=uuid.uuid4().hex[:12],
user_id=user_key,
channel=channel,
source=source,
username=username,
action=action,
keyword=keyword,
title=title,
meta=meta,
items=list(items or []),
)
self._by_id[request.request_id] = request
self._by_user[user_key] = request.request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingMediaInteraction]:
"""
按用户读取当前会话,供文本回复和旧按钮兼容使用。
"""
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingMediaInteraction]:
"""
按请求 ID 读取会话,并校验用户归属。
"""
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
"""
主动结束一条会话。
"""
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self) -> None:
"""
清空所有交互状态,主要用于测试。
"""
with self._lock:
self._by_id.clear()
self._by_user.clear()
media_interaction_manager = MediaInteractionManager()
@dataclass(frozen=True)
class AgentInteractionOption:
"""
Agent 交互选项。
"""
label: str
value: str
@dataclass
class PendingAgentInteraction:
"""
待处理的 Agent 客户端交互请求。
"""
request_id: str
session_id: str
user_id: str
channel: Optional[str]
source: Optional[str]
username: Optional[str]
title: Optional[str]
prompt: str
options: List[AgentInteractionOption]
created_at: datetime = field(default_factory=datetime.now)
class AgentInteractionManager:
"""
管理 Agent 发起的客户端交互请求。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
self._lock = Lock()
def _cleanup_locked(self) -> None:
expire_before = datetime.now() - self._ttl
expired_ids = [
request_id
for request_id, request in self._pending_interactions.items()
if request.created_at < expire_before
]
for request_id in expired_ids:
self._pending_interactions.pop(request_id, None)
def create_request(
self,
session_id: str,
user_id: str,
channel: Optional[str],
source: Optional[str],
username: Optional[str],
title: Optional[str],
prompt: str,
options: List[AgentInteractionOption],
) -> PendingAgentInteraction:
"""
创建一条待用户确认的 Agent 交互请求。
"""
with self._lock:
self._cleanup_locked()
request_id = uuid.uuid4().hex[:12]
while request_id in self._pending_interactions:
request_id = uuid.uuid4().hex[:12]
request = PendingAgentInteraction(
request_id=request_id,
session_id=session_id,
user_id=str(user_id),
channel=channel,
source=source,
username=username,
title=title,
prompt=prompt,
options=options,
)
self._pending_interactions[request_id] = request
return request
def resolve(
self,
request_id: str,
option_index: int,
user_id: Optional[str] = None,
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
"""
消费一条 Agent 交互请求,并返回选中的选项。
"""
with self._lock:
self._cleanup_locked()
request = self._pending_interactions.get(request_id)
if not request:
return None
if user_id is not None and str(request.user_id) != str(user_id):
return None
if option_index < 1 or option_index > len(request.options):
return None
option = request.options[option_index - 1]
self._pending_interactions.pop(request_id, None)
return request, option
def clear(self) -> None:
"""
清空所有 Agent 交互请求。
"""
with self._lock:
self._pending_interactions.clear()
agent_interaction_manager = AgentInteractionManager()
@dataclass
class PendingSkillsInteraction:
"""
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
view: str = "root"
local_page: int = 0
market_page: int = 0
market_query: str = ""
awaiting_input: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
class SkillsInteractionManager:
"""
管理用户当前的技能交互状态。
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingSkillsInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self):
"""
清理超时会话,避免按钮回调无限积累。
"""
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
) -> PendingSkillsInteraction:
"""
为用户创建新会话,并替换掉旧的技能交互状态。
"""
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request_id = uuid.uuid4().hex[:12]
request = PendingSkillsInteraction(
request_id=request_id,
user_id=user_key,
channel=channel,
source=source,
username=username,
)
self._by_id[request_id] = request
self._by_user[user_key] = request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingSkillsInteraction]:
"""
按用户获取当前有效会话,供纯文本回复路由使用。
"""
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingSkillsInteraction]:
"""
按请求 ID 获取会话,并校验会话归属用户。
"""
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
"""
主动结束会话,释放用户和请求 ID 的双向索引。
"""
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self):
"""
清空所有会话,主要用于测试场景。
"""
with self._lock:
self._by_id.clear()
self._by_user.clear()
skills_interaction_manager = SkillsInteractionManager()

View File

@@ -1,244 +0,0 @@
import math
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List, Optional, Sequence, Tuple, Union
from app.schemas import Notification
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
@dataclass
class PendingSlashInteraction:
"""
通用 slash 命令交互上下文。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
command: str
page: int = 0
awaiting_input: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
class SlashInteractionManager:
"""
管理单个 slash 命令的交互会话。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingSlashInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self) -> None:
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
command: str,
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
) -> PendingSlashInteraction:
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request = PendingSlashInteraction(
request_id=uuid.uuid4().hex[:12],
user_id=user_key,
command=command,
channel=channel,
source=source,
username=username,
)
self._by_id[request.request_id] = request
self._by_user[user_key] = request.request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self) -> None:
with self._lock:
self._by_id.clear()
self._by_user.clear()
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
"""
渠道同时支持按钮和回调时,优先使用按钮交互。
"""
return bool(
channel
and ChannelCapabilityManager.supports_buttons(channel)
and ChannelCapabilityManager.supports_callbacks(channel)
)
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
"""
仅在支持 Markdown 的渠道上输出 Markdown 内容。
"""
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
def page_items(
items: Sequence,
page: int,
page_size: int,
) -> Tuple[List, int, int]:
"""
对列表做分页并规范化页码。
"""
total = len(items)
if total == 0:
return [], 0, 1
total_pages = max(1, math.ceil(total / max(1, page_size)))
page = min(max(0, page), total_pages - 1)
start = page * page_size
end = start + page_size
return list(items[start:end]), page, total_pages
def build_navigation_buttons(
prefix: str,
request: PendingSlashInteraction,
page: int,
total_pages: int,
) -> List[List[dict]]:
"""
构造标准上一页/下一页按钮。
"""
buttons = []
nav_row = []
if page > 0:
nav_row.append(
{
"text": "⬅️ 上一页",
"callback_data": f"{prefix}:{request.request_id}:page-prev",
}
)
if page < total_pages - 1:
nav_row.append(
{
"text": "下一页 ➡️",
"callback_data": f"{prefix}:{request.request_id}:page-next",
}
)
if nav_row:
buttons.append(nav_row)
return buttons
def update_or_post_message(
chain,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
title: str,
text: str,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
优先编辑原消息,失败时回退为发送新消息。
"""
if (
original_message_id
and original_chat_id
and ChannelCapabilityManager.supports_editing(channel)
):
edited = chain.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=text,
buttons=buttons,
)
if edited:
return
chain.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=title,
text=text,
buttons=buttons,
)
)
def escape_markdown_table_cell(value: object) -> str:
"""
最小化转义 Markdown 表格中的特殊字符。
"""
text = str(value or "").replace("\n", "<br>")
text = text.replace("|", "\\|")
return text
def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str:
"""
生成 Markdown 表格文本。
"""
header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |"
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
data_lines = [
"| "
+ " | ".join(escape_markdown_table_cell(item) for item in row)
+ " |"
for row in rows
]
return "\n".join([header_line, separator_line, *data_lines])

View File

@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
AskUserChoiceTool,
UserChoiceOptionInput,
)
from app.chain.interaction import (
from app.helper.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)

View File

@@ -9,7 +9,7 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
setattr(sys.modules["transmission_rpc"], "File", object)
sys.modules.setdefault("psutil", ModuleType("psutil"))
from app.chain.interaction import MediaInteractionChain, media_interaction_manager
from app.chain.media import MediaChain, media_interaction_manager
from app.chain.message import MessageChain
from app.core.context import MediaInfo
from app.core.meta import MetaBase
@@ -43,7 +43,7 @@ class TestMediaInteraction(unittest.TestCase):
self.assertIsNotNone(request)
with patch.object(chain, "_record_user_message"), patch(
"app.chain.message.MediaInteractionChain.handle_text_interaction",
"app.chain.message.MediaChain.handle_text_interaction",
return_value=True,
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
chain.handle_message(
@@ -72,7 +72,7 @@ class TestMediaInteraction(unittest.TestCase):
)
with patch(
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
"app.chain.message.MediaChain.handle_callback_interaction",
return_value=True,
) as handle_callback:
chain._handle_callback(
@@ -86,7 +86,7 @@ class TestMediaInteraction(unittest.TestCase):
handle_callback.assert_called_once()
def test_media_interaction_starts_search_and_posts_media_list(self):
chain = MediaInteractionChain()
chain = MediaChain()
meta = self._build_meta("星际穿越")
medias = [
MediaInfo(title="星际穿越", year="2014"),
@@ -94,7 +94,7 @@ class TestMediaInteraction(unittest.TestCase):
]
with patch(
"app.chain.interaction.MediaChain.search",
"app.chain.media.MediaChain.search",
return_value=(meta, medias),
), patch.object(chain, "post_medias_message") as post_medias_message:
handled = chain.handle_text_interaction(
@@ -119,7 +119,7 @@ class TestMediaInteraction(unittest.TestCase):
self.assertEqual(len(request.items), 2)
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
chain = MediaInteractionChain()
chain = MediaChain()
request = media_interaction_manager.create_or_replace(
user_id="10001",
channel=MessageChannel.Telegram,