diff --git a/app/chain/message.py b/app/chain/message.py index e6648e75..4c717e17 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -3,6 +3,7 @@ import re import time from datetime import datetime, timedelta from typing import Any, Optional, Dict, Union, List +from urllib.parse import unquote import base64 @@ -1338,8 +1339,39 @@ class MessageChain(ChainBase): "download_wechat_media_bytes", media_ref=audio_ref, source=source ) filename = "input.amr" + elif audio_ref.startswith("slack://file/"): + content = self.run_module( + "download_slack_file_bytes", file_ref=audio_ref, source=source + ) + filename = self._guess_audio_filename(audio_ref, default="input.ogg") + elif audio_ref.startswith("discord://file/"): + content = self.run_module( + "download_discord_file_bytes", file_ref=audio_ref, source=source + ) + filename = self._guess_audio_filename(audio_ref, default="input.ogg") + elif audio_ref.startswith("qq://file/"): + content = self.run_module( + "download_qq_file_bytes", file_ref=audio_ref, source=source + ) + filename = self._guess_audio_filename(audio_ref, default="input.ogg") + elif audio_ref.startswith("vocechat://file/"): + content = self.run_module( + "download_vocechat_file_bytes", file_ref=audio_ref, source=source + ) + filename = self._guess_audio_filename(audio_ref, default="input.ogg") + elif audio_ref.startswith("synology://file/"): + content = self.run_module( + "download_synologychat_file_bytes", + file_ref=audio_ref, + source=source, + ) + filename = self._guess_audio_filename(audio_ref, default="input.ogg") elif audio_ref.startswith("wxbot://voice"): continue + elif audio_ref.startswith("http"): + resp = RequestUtils(timeout=30).get_res(audio_ref) + content = resp.content if resp and resp.content else None + filename = self._guess_audio_filename(audio_ref, default="input.ogg") else: logger.debug( "暂不支持的语音引用: channel=%s, source=%s, ref=%s", @@ -1349,6 +1381,15 @@ class MessageChain(ChainBase): ) continue + if not content: + 
@staticmethod
def _guess_audio_filename(audio_ref: str, default: str = "input.ogg") -> str:
    """
    Guess an audio file name (with extension) from a media reference.

    The name is only a hint that lets the STT service detect the audio format.

    :param audio_ref: Either a plain ``http(s)`` URL or a channel reference of
        the form ``<channel>://file/<percent-encoded URL or path>``.
    :param default: Name returned when no known audio extension is found.
    :return: The last path segment when it ends with a known audio extension
        (letter case preserved, percent-escapes decoded), otherwise ``default``.
    """
    if not audio_ref:
        return default

    target = audio_ref
    # Channel refs carry a fully percent-encoded payload, so it has to be
    # decoded once before "?" / "#" can be recognised as URL delimiters.
    wrapped = re.match(
        r"(?!https?://)[a-z][a-z0-9+.\-]*://file/(.*)",
        audio_ref,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if wrapped:
        target = unquote(wrapped.group(1))

    # Strip query and fragment BEFORE decoding the path: an encoded "%3F" or
    # "%23" inside the file name must not be mistaken for a delimiter.
    path = target.split("?", 1)[0].split("#", 1)[0]
    basename = unquote(path).rsplit("/", 1)[-1]

    if re.fullmatch(
        r"[^/]+\.(mp3|m4a|wav|ogg|oga|opus|aac|amr|flac|mpga|mpeg|webm)",
        basename,
        flags=re.IGNORECASE,
    ):
        return basename
    return default
@classmethod
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
    """
    Collect the audio attachments of a Discord message as ``discord://file/`` refs.

    An attachment counts as audio when its ``content_type`` starts with
    ``audio/`` or its ``filename`` ends with one of ``cls._AUDIO_SUFFIXES``.

    :param msg_json: Parsed message payload; only ``attachments`` is read.
    :return: De-duplicated refs in payload order, each being the attachment URL
        fully percent-encoded behind the ``discord://file/`` prefix, or ``None``
        when no audio attachment is present.
    """
    attachments = msg_json.get("attachments") or []
    if not isinstance(attachments, (list, tuple)):
        return None

    audio_refs: List[str] = []
    for attachment in attachments:
        # Skip malformed entries instead of raising AttributeError on .get().
        if not isinstance(attachment, dict):
            continue
        url = attachment.get("url") or attachment.get("proxy_url")
        if not url or not isinstance(url, str):
            continue
        content_type = str(attachment.get("content_type") or "").lower()
        filename = str(attachment.get("filename") or "").lower()
        if not (
            content_type.startswith("audio/")
            or filename.endswith(cls._AUDIO_SUFFIXES)
        ):
            continue
        ref = f"discord://file/{quote(url, safe='')}"
        # The same attachment listed twice must not be downloaded twice.
        if ref not in audio_refs:
            audio_refs.append(ref)
    return audio_refs or None

def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download the Discord attachment behind a ``discord://file/`` ref.

    :param file_ref: Ref whose payload is the percent-encoded attachment URL.
    :param source: Name handed to ``self.get_config``; nothing is downloaded
        when that lookup returns a falsy value.
    :return: The raw response body, or ``None`` when the ref is malformed, the
        source is unknown, or the response is falsy / has no content.
    """
    prefix = "discord://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    if not self.get_config(source):
        return None
    # The prefix is known to sit at index 0, so slicing removes exactly it.
    file_url = unquote(file_ref[len(prefix):])
    resp = RequestUtils(timeout=30).get_res(file_url)
    if resp and resp.content:
        return resp.content
    return None
**kwargs) -> None: """ 发送通知消息 diff --git a/app/modules/qqbot/__init__.py b/app/modules/qqbot/__init__.py index e284bfba..8ef699f7 100644 --- a/app/modules/qqbot/__init__.py +++ b/app/modules/qqbot/__init__.py @@ -5,6 +5,7 @@ QQ Bot 通知模块 """ import json +from urllib.parse import quote, unquote from typing import Optional, List, Tuple, Union, Any from app.core.context import MediaInfo, Context @@ -13,6 +14,7 @@ from app.modules import _ModuleBase, _MessageBase from app.modules.qqbot.qqbot import QQBot from app.schemas import CommingMessage, MessageChannel, Notification from app.schemas.types import ModuleType +from app.utils.http import RequestUtils class QQBotModule(_ModuleBase, _MessageBase[QQBot]): @@ -28,6 +30,20 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): ".tiff", ".svg", ) + _AUDIO_SUFFIXES = ( + ".mp3", + ".m4a", + ".wav", + ".ogg", + ".oga", + ".opus", + ".aac", + ".amr", + ".flac", + ".mpga", + ".mpeg", + ".webm", + ) def init_module(self) -> None: self.stop() @@ -90,7 +106,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): msg_type = msg_body.get("type") content = (msg_body.get("content") or "").strip() images = self._extract_images(msg_body) - if not content and not images: + audio_refs = self._extract_audio_refs(msg_body) + if not content and not images and not audio_refs: return None if msg_type == "C2C_MESSAGE_CREATE": @@ -100,7 +117,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): return None logger.info( f"收到 QQ 私聊消息: userid={user_openid}, " - f"text={(content or '')[:50]}..., images={len(images) if images else 0}" + f"text={(content or '')[:50]}..., images={len(images) if images else 0}, " + f"audios={len(audio_refs) if audio_refs else 0}" ) return CommingMessage( channel=MessageChannel.QQ, @@ -109,6 +127,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): username=user_openid, text=content, images=images, + audio_refs=audio_refs, ) elif msg_type == "GROUP_AT_MESSAGE_CREATE": author = msg_body.get("author", {}) @@ 
@classmethod
def _extract_audio_refs(cls, msg_body: dict) -> Optional[List[str]]:
    """
    Collect the audio attachments of a QQ message as ``qq://file/`` refs.

    An attachment is audio when its ``content_type`` / ``mime_type`` starts
    with ``audio/`` or its ``filename`` / ``name`` ends with one of
    ``cls._AUDIO_SUFFIXES``. Entries that are not dicts or carry no URL are
    ignored.

    :param msg_body: Message payload; only ``attachments`` is read.
    :return: De-duplicated refs in payload order, or ``None`` when empty.
    """
    entries = msg_body.get("attachments") or []
    if not isinstance(entries, list):
        return None

    unique_refs: List[str] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        link = entry.get("url") or entry.get("proxy_url")
        if not link:
            continue
        mime = (entry.get("content_type") or entry.get("mime_type") or "").lower()
        label = (entry.get("filename") or entry.get("name") or "").lower()
        is_audio = mime.startswith("audio/") or label.endswith(cls._AUDIO_SUFFIXES)
        if not is_audio:
            continue
        ref = f"qq://file/{quote(link, safe='')}"
        if ref not in unique_refs:
            unique_refs.append(ref)
    return unique_refs or None

def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download the QQ audio attachment behind a ``qq://file/`` ref.

    :param file_ref: Ref whose payload is the percent-encoded attachment URL.
    :param source: Name handed to ``self.get_config``; a falsy lookup aborts.
    :return: The raw response body, or ``None`` on any failure.
    """
    prefix = "qq://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    if not self.get_config(source):
        return None
    # Prefix is at index 0, so slicing equals the single leading replace.
    target_url = unquote(file_ref[len(prefix):])
    response = RequestUtils(timeout=30).get_res(target_url)
    if not response or not response.content:
        return None
    return response.content
@classmethod
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
    """
    Collect the audio files of a Slack message as ``slack://file/`` refs.

    A file is audio when its ``type`` is ``audio``, its ``mimetype`` starts
    with ``audio/``, or its ``filetype`` (with a leading dot added) is one of
    ``cls._AUDIO_SUFFIXES``. ``url_private_download`` is preferred over
    ``url_private``; files with neither are dropped.

    :param msg_json: Message or event payload; only ``files`` is read.
    :return: Refs in payload order, or ``None`` when there are none.
    """
    attached = msg_json.get("files", [])
    if not attached:
        return None

    collected = []
    for item in attached:
        kind = str(item.get("type", "")).lower()
        suffix = "." + str(item.get("filetype", "")).lower().lstrip(".")
        mime = str(item.get("mimetype", "")).lower()
        if (
            kind != "audio"
            and not mime.startswith("audio/")
            and suffix not in cls._AUDIO_SUFFIXES
        ):
            continue
        link = item.get("url_private_download") or item.get("url_private")
        if not link:
            continue
        collected.append(f"slack://file/{quote(link, safe='')}")
    return collected or None

def download_slack_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download the Slack audio file behind a ``slack://file/`` ref.

    The download goes through the client instance registered for ``source``
    (``client.download_file``), whose truthy result is unpacked as a pair and
    whose first element is returned.

    :param file_ref: Ref whose payload is the percent-encoded file URL.
    :param source: Name handed to ``self.get_config``.
    :return: The file content, or ``None`` when the ref is malformed, the
        config or client is missing, or the download yields nothing.
    """
    prefix = "slack://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    conf = self.get_config(source)
    if not conf:
        return None
    slack_client = self.get_instance(conf.name)
    if not slack_client:
        return None
    downloaded = slack_client.download_file(unquote(file_ref[len(prefix):]))
    if not downloaded:
        return None
    payload, _ = downloaded
    return payload
a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 12fca4de..bb0390c9 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,5 +1,6 @@ -from typing import Optional, Union, List, Tuple, Any import json +from typing import Optional, Union, List, Tuple, Any +from urllib.parse import quote, unquote from app.core.context import MediaInfo, Context from app.log import logger @@ -7,6 +8,7 @@ from app.modules import _ModuleBase, _MessageBase from app.modules.synologychat.synologychat import SynologyChat from app.schemas import MessageChannel, CommingMessage, Notification from app.schemas.types import ModuleType +from app.utils.http import RequestUtils class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): @@ -20,6 +22,20 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): ".tiff", ".svg", ) + _AUDIO_SUFFIXES = ( + ".mp3", + ".m4a", + ".wav", + ".ogg", + ".oga", + ".opus", + ".aac", + ".amr", + ".flac", + ".mpga", + ".mpeg", + ".webm", + ) def init_module(self) -> None: """ @@ -108,14 +124,16 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): # 获取用户名 user_name = message.get("username") images = self._extract_images(message) - if (text or images) and user_id: + audio_refs = self._extract_audio_refs(message) + if (text or images or audio_refs) and user_id: logger.info( f"收到来自 {client_config.name} 的SynologyChat消息:" - f"userid={user_id}, username={user_name}, text={text}, images={len(images) if images else 0}" + f"userid={user_id}, username={user_name}, text={text}, " + f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" ) return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name, userid=user_id, username=user_name, text=text or "", - images=images) + images=images, audio_refs=audio_refs) except Exception as err: logger.debug(f"解析SynologyChat消息失败:{str(err)}") return None @@ -151,6 +169,49 @@ 
@classmethod
def _extract_audio_refs(cls, message: dict) -> Optional[List[str]]:
    """
    Collect audio URLs of a Synology Chat message as ``synology://file/`` refs.

    Two places are inspected:

    * the string fields ``audio_url``, ``voice_url`` and ``file_url``, kept
      when ``cls._looks_like_audio`` accepts them;
    * ``attachments`` and ``files``, each either a JSON string or an already
      parsed value. String items are treated like the fields above; dict
      items are kept when their content type starts with ``audio/``, their
      name ends with an audio suffix, or their URL looks like audio.

    :param message: Incoming message mapping (only ``.get`` is used).
    :return: De-duplicated refs in discovery order, or ``None`` when empty.
    """
    audio_refs: List[str] = []

    def _add(url: str) -> None:
        ref = f"synology://file/{quote(url, safe='')}"
        if ref not in audio_refs:
            audio_refs.append(ref)

    for key in ("audio_url", "voice_url", "file_url"):
        value = message.get(key)
        if isinstance(value, str) and cls._looks_like_audio(value):
            _add(value)

    for key in ("attachments", "files"):
        raw_value = message.get(key)
        if not raw_value:
            continue
        parsed = raw_value
        if isinstance(raw_value, str):
            try:
                parsed = json.loads(raw_value)
            except ValueError:
                # Not JSON: the raw string itself may be a plain URL.
                parsed = raw_value
        items = parsed if isinstance(parsed, list) else [parsed]
        for item in items:
            if isinstance(item, str):
                if cls._looks_like_audio(item):
                    _add(item)
                continue
            if not isinstance(item, dict):
                continue
            url = item.get("url") or item.get("file_url") or item.get("audio_url")
            # An empty string would otherwise yield a bare "synology://file/".
            if not url or not isinstance(url, str):
                continue
            content_type = str(
                item.get("content_type") or item.get("mime_type") or ""
            ).lower()
            name = str(item.get("name") or item.get("filename") or "").lower()
            if (
                content_type.startswith("audio/")
                or name.endswith(cls._AUDIO_SUFFIXES)
                or cls._looks_like_audio(url)
            ):
                _add(url)
    return audio_refs or None

@classmethod
def _looks_like_audio(cls, value: str) -> bool:
    """
    Tell whether ``value`` is an http(s) URL that points at an audio file.

    The URL path, or any query-parameter value, must END with one of
    ``cls._AUDIO_SUFFIXES``. A plain substring test would also accept
    ``.amr`` inside a host name or ``report.mp3.pdf``.

    :param value: Candidate URL; non-strings and empty values are rejected.
    :return: ``True`` when the URL ends in a known audio extension.
    """
    if not value or not isinstance(value, str):
        return False
    lowered = value.lower()
    if not lowered.startswith("http"):
        return False
    # Function-scope import keeps this block independent of file-level imports.
    from urllib.parse import parse_qsl, urlsplit

    try:
        parts = urlsplit(lowered)
    except ValueError:
        # urlsplit rejects e.g. malformed bracketed hosts.
        return False
    candidates = [unquote(parts.path)]
    candidates.extend(val for _, val in parse_qsl(parts.query))
    return any(candidate.endswith(cls._AUDIO_SUFFIXES) for candidate in candidates)

def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download the Synology Chat audio file behind a ``synology://file/`` ref.

    :param file_ref: Ref whose payload is the percent-encoded file URL.
    :param source: Name handed to ``self.get_config``; a falsy lookup aborts.
    :return: The raw response body, or ``None`` when the ref is malformed, the
        source is unknown, or the response is falsy / has no content.
    """
    prefix = "synology://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    if not self.get_config(source):
        return None
    file_url = unquote(file_ref[len(prefix):])
    resp = RequestUtils(timeout=30).get_res(file_url)
    if resp and resp.content:
        return resp.content
    return None
@classmethod
def _extract_audio_refs(cls, detail: dict) -> Optional[List[str]]:
    """
    Turn a VoceChat file message into a single ``vocechat://file/`` ref.

    Only details whose ``content_type`` is ``vocechat/file`` are considered.
    The file is audio when the MIME type found in ``properties`` starts with
    ``audio/`` or the file name ends with one of ``cls._AUDIO_SUFFIXES``.
    The name falls back to the last segment of the file path.

    :param detail: The ``detail`` object of the incoming message.
    :return: A one-element list holding the percent-encoded file path behind
        the ``vocechat://file/`` prefix, or ``None`` when the message is not
        an audio file or has no usable string path.
    """
    if (detail.get("content_type") or "") != "vocechat/file":
        return None

    props = detail.get("properties") or {}

    def first_truthy(*keys):
        # Value of the first key in ``props`` that is truthy, else None.
        for key in keys:
            found = props.get(key)
            if found:
                return found
        return None

    mime = (first_truthy("content_type", "mime_type", "contentType") or "").lower()
    location = first_truthy("path", "file_path", "storage_path") or detail.get("content")
    fallback_name = str(location).rsplit("/", 1)[-1] if location else ""
    label = (first_truthy("name", "filename") or fallback_name).lower()

    if not mime.startswith("audio/") and not label.endswith(cls._AUDIO_SUFFIXES):
        return None
    if not isinstance(location, str) or not location:
        return None
    return [f"vocechat://file/{quote(location, safe='')}"]

def download_vocechat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download the VoceChat file behind a ``vocechat://file/`` ref.

    The download goes through the client instance registered for ``source``
    (``client.download_file``), whose truthy result is unpacked as a pair and
    whose first element is returned.

    :param file_ref: Ref whose payload is the percent-encoded file path.
    :param source: Name handed to ``self.get_config``.
    :return: The file content, or ``None`` when the ref is malformed, the
        config or client is missing, or the download yields nothing.
    """
    prefix = "vocechat://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    conf = self.get_config(source)
    if not conf:
        return None
    voce_client: VoceChat = self.get_instance(conf.name)
    if not voce_client:
        return None
    downloaded = voce_client.download_file(unquote(file_ref[len(prefix):]))
    if not downloaded:
        return None
    payload, _ = downloaded
    return payload
def test_transcribe_audio_refs_supports_new_channel_refs(self):
    """
    Check that ``_transcribe_audio_refs`` dispatches every new channel ref.

    ``run_module`` and ``VoiceHelper.transcribe_bytes`` are patched with
    ordered ``side_effect`` lists, so the assertions pin three things: the
    joined transcript, the download method name chosen for each ref prefix,
    and the ``filename`` keyword passed to the transcriber for each ref.
    """
    chain = MessageChain()
    audio_refs = [
        "slack://file/" + quote("https://files.slack.com/test.mp3", safe=""),
        "discord://file/" + quote("https://cdn.discordapp.com/voice.ogg", safe=""),
        "qq://file/" + quote("https://example.com/qq-voice.ogg", safe=""),
        "vocechat://file/%2Fuploads%2Fvoice.ogg",
        "synology://file/" + quote("https://example.com/synology-voice.wav", safe=""),
    ]

    # side_effect lists are consumed in call order, one item per audio ref.
    with patch.object(VoiceHelper, "is_available", return_value=True), patch.object(
        chain,
        "run_module",
        side_effect=[b"slack", b"discord", b"qq", b"vocechat", b"synology"],
    ) as run_module, patch.object(
        VoiceHelper,
        "transcribe_bytes",
        side_effect=["slack text", "discord text", "qq text", "vocechat text", "synology text"],
    ) as transcribe_bytes:
        result = chain._transcribe_audio_refs(
            audio_refs=audio_refs,
            channel=MessageChannel.Slack,
            source="mixed-source",
        )

    self.assertEqual(
        result,
        "slack text\ndiscord text\nqq text\nvocechat text\nsynology text",
    )
    # First positional argument of run_module is the module method name.
    self.assertEqual(
        [call.args[0] for call in run_module.call_args_list],
        [
            "download_slack_file_bytes",
            "download_discord_file_bytes",
            "download_qq_file_bytes",
            "download_vocechat_file_bytes",
            "download_synologychat_file_bytes",
        ],
    )
    self.assertEqual(
        [call.kwargs["filename"] for call in transcribe_bytes.call_args_list],
        [
            "test.mp3",
            "voice.ogg",
            "qq-voice.ogg",
            "voice.ogg",
            "synology-voice.wav",
        ],
    )

def test_slack_extract_audio_refs_returns_private_file_refs(self):
    """
    Check that a Slack audio file with only ``url_private`` yields one ref.

    The expected ref is the ``slack://file/`` prefix followed by the fully
    percent-encoded private URL.
    """
    audio_refs = SlackModule._extract_audio_refs(
        {
            "files": [
                {
                    "type": "audio",
                    "filetype": "mp3",
                    "mimetype": "audio/mpeg",
                    "url_private": "https://files.slack.com/files-pri/T1-F1/test.mp3",
                }
            ]
        }
    )

    self.assertEqual(
        audio_refs,
        [
            "slack://file/"
            + quote("https://files.slack.com/files-pri/T1-F1/test.mp3", safe="")
        ],
    )
def test_discord_extract_audio_refs_supports_attachment_content_type(self):
    """
    Check that a Discord attachment with an ``audio/ogg`` content type is
    returned as a ``discord://file/`` ref holding the percent-encoded URL.
    """
    audio_refs = DiscordModule._extract_audio_refs(
        {
            "attachments": [
                {
                    "content_type": "audio/ogg",
                    "filename": "voice.ogg",
                    "url": "https://cdn.discordapp.com/voice.ogg",
                }
            ]
        }
    )

    self.assertEqual(
        audio_refs,
        [
            "discord://file/"
            + quote("https://cdn.discordapp.com/voice.ogg", safe="")
        ],
    )

def test_vocechat_message_parser_extracts_audio_file_payload(self):
    """
    Check that ``message_parser`` exposes a VoceChat audio file message as
    ``audio_refs`` containing the percent-encoded file path.
    """
    module = VoceChatModule()
    body = json.dumps(
        {
            "detail": {
                "type": "normal",
                "content_type": "vocechat/file",
                "content": "/uploads/voice.ogg",
                "properties": {"content_type": "audio/ogg"},
            },
            "from_uid": 7910,
            "target": {"gid": 2},
        }
    )

    # get_config is patched so no real VoceChat configuration is needed.
    with patch.object(
        module,
        "get_config",
        return_value=SimpleNamespace(
            name="vocechat-test", config={"channel_id": "2"}
        ),
    ):
        message = module.message_parser(
            source="vocechat-test",
            body=body,
            form={},
            args={},
        )

    self.assertIsNotNone(message)
    self.assertEqual(
        message.audio_refs,
        ["vocechat://file/%2Fuploads%2Fvoice.ogg"],
    )

def test_qq_message_parser_accepts_audio_only_attachment(self):
    """
    Check that a QQ private message with no text and a single audio
    attachment is still parsed and carries the ``qq://file/`` ref.
    """
    module = QQBotModule()

    with patch.object(
        module,
        "get_config",
        return_value=SimpleNamespace(name="qq-test", config={}),
    ):
        message = module.message_parser(
            source="qq-test",
            body={
                "type": "C2C_MESSAGE_CREATE",
                "author": {"user_openid": "qq-user"},
                "attachments": [
                    {
                        "content_type": "audio/ogg",
                        "filename": "voice.ogg",
                        "url": "https://example.com/qq-voice.ogg",
                    }
                ],
            },
            form={},
            args={},
        )

    self.assertIsNotNone(message)
    self.assertEqual(
        message.audio_refs,
        ["qq://file/" + quote("https://example.com/qq-voice.ogg", safe="")],
    )

def test_synology_message_parser_accepts_audio_only_form(self):
    """
    Check that a Synology Chat form with no text but a JSON-encoded audio
    attachment is parsed and carries the ``synology://file/`` ref.
    """
    module = SynologyChatModule()

    # get_instance returns a stub whose check_token accepts only "token-1".
    with patch.object(
        module,
        "get_config",
        return_value=SimpleNamespace(name="synology-test", config={}),
    ), patch.object(
        module,
        "get_instance",
        return_value=SimpleNamespace(check_token=lambda token: token == "token-1"),
    ):
        message = module.message_parser(
            source="synology-test",
            body={},
            form={
                "token": "token-1",
                "user_id": "42",
                "username": "tester",
                "attachments": json.dumps(
                    [
                        {
                            "url": "https://example.com/voice.ogg",
                            "content_type": "audio/ogg",
                        }
                    ]
                ),
            },
            args={},
        )

    self.assertIsNotNone(message)
    self.assertEqual(
        message.audio_refs,
        ["synology://file/" + quote("https://example.com/voice.ogg", safe="")],
    )