feat(agent): add audio message extraction and download support for Slack, QQ, Discord, SynologyChat, and VoceChat

This commit is contained in:
jxxghp
2026-04-13 08:36:57 +08:00
parent 8d938c2273
commit e09f9ad009
7 changed files with 601 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ import re
import time
from datetime import datetime, timedelta
from typing import Any, Optional, Dict, Union, List
from urllib.parse import unquote
import base64
@@ -1338,8 +1339,39 @@ class MessageChain(ChainBase):
"download_wechat_media_bytes", media_ref=audio_ref, source=source
)
filename = "input.amr"
elif audio_ref.startswith("slack://file/"):
content = self.run_module(
"download_slack_file_bytes", file_ref=audio_ref, source=source
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("discord://file/"):
content = self.run_module(
"download_discord_file_bytes", file_ref=audio_ref, source=source
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("qq://file/"):
content = self.run_module(
"download_qq_file_bytes", file_ref=audio_ref, source=source
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("vocechat://file/"):
content = self.run_module(
"download_vocechat_file_bytes", file_ref=audio_ref, source=source
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("synology://file/"):
content = self.run_module(
"download_synologychat_file_bytes",
file_ref=audio_ref,
source=source,
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("wxbot://voice"):
continue
elif audio_ref.startswith("http"):
resp = RequestUtils(timeout=30).get_res(audio_ref)
content = resp.content if resp and resp.content else None
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
else:
logger.debug(
"暂不支持的语音引用: channel=%s, source=%s, ref=%s",
@@ -1349,6 +1381,15 @@ class MessageChain(ChainBase):
)
continue
if not content:
logger.warning(
"语音下载失败,跳过识别: channel=%s, source=%s, ref=%s",
channel.value if channel else None,
source,
audio_ref,
)
continue
transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename)
if transcript:
transcripts.append(transcript)
@@ -1364,6 +1405,23 @@ class MessageChain(ChainBase):
return "\n".join(transcripts).strip() if transcripts else None
@staticmethod
def _guess_audio_filename(audio_ref: str, default: str = "input.ogg") -> str:
"""
根据引用中的扩展名推测音频文件名,便于 STT 服务识别格式。
"""
if not audio_ref:
return default
raw_ref = unquote(audio_ref).split("?", 1)[0].split("#", 1)[0]
match = re.search(
r"([^/]+\.(mp3|m4a|wav|ogg|oga|opus|aac|amr|flac|mpga|mpeg|webm))$",
raw_ref,
flags=re.IGNORECASE,
)
if match:
return match.group(1)
return default
def _download_images_to_base64(
self, images: List[str], channel: MessageChannel, source: str
) -> List[str]:

View File

@@ -1,4 +1,5 @@
import json
from urllib.parse import quote, unquote
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
@@ -6,6 +7,7 @@ from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
from app.schemas.types import ModuleType
from app.utils.http import RequestUtils
try:
from app.modules.discord.discord import Discord
@@ -25,6 +27,20 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
".tiff",
".svg",
)
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
"""
@@ -142,10 +158,12 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
text = msg_json.get("text")
chat_id = msg_json.get("chat_id")
images = self._extract_images(msg_json)
if (text or images) and userid:
audio_refs = self._extract_audio_refs(msg_json)
if (text or images or audio_refs) and userid:
logger.info(
f"收到来自 {client_config.name} 的 Discord 消息:"
f"userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
f"userid={userid}, username={username}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(
channel=MessageChannel.Discord,
@@ -155,6 +173,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
text=text,
chat_id=str(chat_id) if chat_id else None,
images=images,
audio_refs=audio_refs,
)
return None
@@ -181,6 +200,39 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
images.append(url)
return images if images else None
@classmethod
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
    """
    Collect audio attachment references from a Discord message payload.

    An attachment counts as audio when its ``content_type`` starts with
    ``audio/`` or its ``filename`` ends with one of ``_AUDIO_SUFFIXES``.

    :param msg_json: parsed message; its ``attachments`` entry is read
    :return: list of ``discord://file/<percent-encoded url>`` references,
             or None when there are no audio attachments
    """
    collected = []
    for entry in msg_json.get("attachments", []) or ():
        # Prefer the direct URL, fall back to the proxy URL.
        link = entry.get("url") or entry.get("proxy_url")
        if not link:
            continue
        mime = (entry.get("content_type") or "").lower()
        name = (entry.get("filename") or "").lower()
        if not (mime.startswith("audio/") or name.endswith(cls._AUDIO_SUFFIXES)):
            continue
        collected.append(f"discord://file/{quote(link, safe='')}")
    return collected or None
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a Discord attachment and return its raw bytes.

    :param file_ref: reference of the form ``discord://file/<percent-encoded url>``
    :param source: configuration name, looked up through ``get_config``
    :return: response content, or None when the reference is missing or
             malformed, the source has no configuration, or the request
             yields no content
    """
    scheme = "discord://file/"
    if not (file_ref and file_ref.startswith(scheme)):
        return None
    if not self.get_config(source):
        return None
    # Everything after the scheme prefix is the percent-encoded download URL.
    remote_url = unquote(file_ref[len(scheme):])
    reply = RequestUtils(timeout=30).get_res(remote_url)
    if not reply or not reply.content:
        return None
    return reply.content
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送通知消息

View File

@@ -5,6 +5,7 @@ QQ Bot 通知模块
"""
import json
from urllib.parse import quote, unquote
from typing import Optional, List, Tuple, Union, Any
from app.core.context import MediaInfo, Context
@@ -13,6 +14,7 @@ from app.modules import _ModuleBase, _MessageBase
from app.modules.qqbot.qqbot import QQBot
from app.schemas import CommingMessage, MessageChannel, Notification
from app.schemas.types import ModuleType
from app.utils.http import RequestUtils
class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
@@ -28,6 +30,20 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
".tiff",
".svg",
)
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
self.stop()
@@ -90,7 +106,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
msg_type = msg_body.get("type")
content = (msg_body.get("content") or "").strip()
images = self._extract_images(msg_body)
if not content and not images:
audio_refs = self._extract_audio_refs(msg_body)
if not content and not images and not audio_refs:
return None
if msg_type == "C2C_MESSAGE_CREATE":
@@ -100,7 +117,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
return None
logger.info(
f"收到 QQ 私聊消息: userid={user_openid}, "
f"text={(content or '')[:50]}..., images={len(images) if images else 0}"
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(
channel=MessageChannel.QQ,
@@ -109,6 +127,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
username=user_openid,
text=content,
images=images,
audio_refs=audio_refs,
)
elif msg_type == "GROUP_AT_MESSAGE_CREATE":
author = msg_body.get("author", {})
@@ -118,7 +137,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
userid = f"group:{group_openid}" if group_openid else member_openid
logger.info(
f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, "
f"text={(content or '')[:50]}..., images={len(images) if images else 0}"
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(
channel=MessageChannel.QQ,
@@ -127,6 +147,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
username=member_openid or group_openid,
text=content,
images=images,
audio_refs=audio_refs,
)
return None
@@ -175,6 +196,50 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
deduped.append(image)
return deduped or None
@classmethod
def _extract_audio_refs(cls, msg_body: dict) -> Optional[List[str]]:
    """
    Collect de-duplicated audio attachment references from a QQ message body.

    Only dict entries of a list-typed ``attachments`` field are inspected.
    An entry counts as audio when its ``content_type``/``mime_type`` starts
    with ``audio/`` or its ``filename``/``name`` ends with one of
    ``_AUDIO_SUFFIXES``.

    :param msg_body: parsed message body
    :return: ``qq://file/<percent-encoded url>`` references in first-seen
             order, or None when nothing qualifies
    """
    attachments = msg_body.get("attachments") or []
    if not isinstance(attachments, list):
        return None

    # dict keys keep insertion order, giving order-preserving de-duplication.
    unique_refs: dict = {}
    for entry in attachments:
        if not isinstance(entry, dict):
            continue
        link = entry.get("url") or entry.get("proxy_url")
        if not link:
            continue
        mime = (entry.get("content_type") or entry.get("mime_type") or "").lower()
        name = (entry.get("filename") or entry.get("name") or "").lower()
        if not (mime.startswith("audio/") or name.endswith(cls._AUDIO_SUFFIXES)):
            continue
        unique_refs.setdefault(f"qq://file/{quote(link, safe='')}", None)
    return list(unique_refs) or None
def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a QQ audio attachment and return its raw bytes.

    :param file_ref: reference of the form ``qq://file/<percent-encoded url>``
    :param source: configuration name, looked up through ``get_config``
    :return: response content, or None when the reference is missing or
             malformed, the source has no configuration, or the request
             yields no content
    """
    scheme = "qq://file/"
    if not (file_ref and file_ref.startswith(scheme)):
        return None
    if not self.get_config(source):
        return None
    # Everything after the scheme prefix is the percent-encoded download URL.
    remote_url = unquote(file_ref[len(scheme):])
    reply = RequestUtils(timeout=30).get_res(remote_url)
    if not reply or not reply.content:
        return None
    return reply.content
def post_message(self, message: Notification, **kwargs) -> None:
for conf in self.get_configs().values():
if not self.check_message(message, conf.name):

View File

@@ -1,5 +1,6 @@
import json
import re
from urllib.parse import quote, unquote
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
@@ -11,6 +12,21 @@ from app.schemas.types import ModuleType
class SlackModule(_ModuleBase, _MessageBase[Slack]):
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
"""
初始化模块
@@ -204,11 +220,13 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return None
if msg_json:
images = None
audio_refs = None
if msg_json.get("type") == "message":
userid = msg_json.get("user")
text = msg_json.get("text")
username = msg_json.get("user")
images = self._extract_images(msg_json)
audio_refs = self._extract_audio_refs(msg_json)
elif msg_json.get("type") == "block_actions":
userid = msg_json.get("user", {}).get("id")
callback_data = msg_json.get("actions")[0].get("value")
@@ -251,6 +269,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
).strip()
username = ""
images = self._extract_images(msg_json.get("event", {}))
audio_refs = self._extract_audio_refs(msg_json.get("event", {}))
elif msg_json.get("type") == "shortcut":
userid = msg_json.get("user", {}).get("id")
text = msg_json.get("callback_id")
@@ -262,7 +281,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
else:
return None
logger.info(
f"收到来自 {client_config.name} 的Slack消息userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
f"收到来自 {client_config.name} 的Slack消息userid={userid}, username={username}, "
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(
channel=MessageChannel.Slack,
@@ -271,6 +291,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
username=username,
text=text,
images=images,
audio_refs=audio_refs,
)
return None
@@ -297,6 +318,29 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
images.append(url)
return images if images else None
@classmethod
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
    """
    Collect audio file references from a Slack message payload.

    A file counts as audio when its ``type`` equals ``audio``, its
    ``mimetype`` starts with ``audio/``, or its ``filetype`` (with a dot
    prepended) is one of ``_AUDIO_SUFFIXES``. Files without a private
    download URL are ignored.

    :param msg_json: parsed message or event; its ``files`` entry is read
    :return: list of ``slack://file/<percent-encoded url>`` references,
             or None when there are no audio files
    """
    collected = []
    for entry in msg_json.get("files", []) or ():
        kind = str(entry.get("type", "")).lower()
        # Normalise "mp3" and ".mp3" alike to ".mp3" before the lookup.
        extension = "." + str(entry.get("filetype", "")).lower().lstrip(".")
        mime = str(entry.get("mimetype", "")).lower()
        is_audio = (
            kind == "audio"
            or mime.startswith("audio/")
            or extension in cls._AUDIO_SUFFIXES
        )
        if not is_audio:
            continue
        link = entry.get("url_private_download") or entry.get("url_private")
        if link:
            collected.append(f"slack://file/{quote(link, safe='')}")
    return collected or None
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
"""
下载Slack文件并转为data URL
@@ -318,6 +362,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
return None
def download_slack_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a Slack audio file and return its raw bytes.

    :param file_ref: reference of the form ``slack://file/<percent-encoded url>``
    :param source: configuration name, looked up through ``get_config``
    :return: the first element of the pair returned by the client's
             ``download_file``, or None when the reference is missing or
             malformed, no configuration or client instance exists, or the
             download returns nothing
    """
    scheme = "slack://file/"
    if not (file_ref and file_ref.startswith(scheme)):
        return None
    conf = self.get_config(source)
    if not conf:
        return None
    slack_client = self.get_instance(conf.name)
    if not slack_client:
        return None
    # Everything after the scheme prefix is the percent-encoded private URL.
    private_url = unquote(file_ref[len(scheme):])
    downloaded = slack_client.download_file(private_url)
    if not downloaded:
        return None
    payload, _ = downloaded
    return payload
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息

View File

@@ -1,5 +1,6 @@
from typing import Optional, Union, List, Tuple, Any
import json
from typing import Optional, Union, List, Tuple, Any
from urllib.parse import quote, unquote
from app.core.context import MediaInfo, Context
from app.log import logger
@@ -7,6 +8,7 @@ from app.modules import _ModuleBase, _MessageBase
from app.modules.synologychat.synologychat import SynologyChat
from app.schemas import MessageChannel, CommingMessage, Notification
from app.schemas.types import ModuleType
from app.utils.http import RequestUtils
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
@@ -20,6 +22,20 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
".tiff",
".svg",
)
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
"""
@@ -108,14 +124,16 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
# 获取用户名
user_name = message.get("username")
images = self._extract_images(message)
if (text or images) and user_id:
audio_refs = self._extract_audio_refs(message)
if (text or images or audio_refs) and user_id:
logger.info(
f"收到来自 {client_config.name} 的SynologyChat消息"
f"userid={user_id}, username={user_name}, text={text}, images={len(images) if images else 0}"
f"userid={user_id}, username={user_name}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name,
userid=user_id, username=user_name, text=text or "",
images=images)
images=images, audio_refs=audio_refs)
except Exception as err:
logger.debug(f"解析SynologyChat消息失败{str(err)}")
return None
@@ -151,6 +169,49 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
deduped.append(image)
return deduped or None
@classmethod
def _extract_audio_refs(cls, message: dict) -> Optional[List[str]]:
    """
    Collect de-duplicated audio references from a Synology Chat message.

    Two sources are scanned in order:
      1. the string fields ``audio_url``, ``voice_url`` and ``file_url``;
      2. the ``attachments`` and ``files`` fields, which may hold a JSON
         string, a list, or a single item. String items are treated as
         URLs; dict items are checked by URL, MIME type and file name.

    :param message: parsed message fields
    :return: ``synology://file/<percent-encoded url>`` references in
             first-seen order, or None when nothing qualifies
    """
    # dict keys keep insertion order, giving order-preserving de-duplication.
    unique_refs: dict = {}

    def _remember(link: str) -> None:
        unique_refs.setdefault(f"synology://file/{quote(link, safe='')}", None)

    for field in ("audio_url", "voice_url", "file_url"):
        candidate = message.get(field)
        if isinstance(candidate, str) and cls._looks_like_audio(candidate):
            _remember(candidate)

    for field in ("attachments", "files"):
        payload = message.get(field)
        if not payload:
            continue
        decoded = payload
        if isinstance(payload, str):
            # Best effort: a string that is not JSON is kept as-is and
            # examined below as a single candidate URL.
            try:
                decoded = json.loads(payload)
            except Exception:
                decoded = payload
        entries = decoded if isinstance(decoded, list) else [decoded]
        for entry in entries:
            if isinstance(entry, str):
                if cls._looks_like_audio(entry):
                    _remember(entry)
                continue
            if not isinstance(entry, dict):
                continue
            link = entry.get("url") or entry.get("file_url") or entry.get("audio_url")
            if not isinstance(link, str):
                continue
            mime = (entry.get("content_type") or entry.get("mime_type") or "").lower()
            label = (entry.get("name") or entry.get("filename") or "").lower()
            if (
                mime.startswith("audio/")
                or cls._looks_like_audio(link)
                or label.endswith(cls._AUDIO_SUFFIXES)
            ):
                _remember(link)

    return list(unique_refs) or None
@classmethod
def _looks_like_image(cls, value: str) -> bool:
if not value or not isinstance(value, str):
@@ -160,6 +221,29 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
suffix in lowered for suffix in cls._IMAGE_SUFFIXES
)
@classmethod
def _looks_like_audio(cls, value: str) -> bool:
    """
    Tell whether ``value`` is an http(s) URL that points at an audio file.

    The extension must terminate the URL path or one of the query-string
    values. A plain substring test over the whole URL would also accept
    hosts and names that merely contain an extension, such as
    ``www.aacademy.com`` (".aac") or ``song.mp3.zip`` (".mp3").

    :param value: candidate URL
    :return: True when the URL starts with "http" (case-insensitive) and
             its path or a query value ends with one of ``_AUDIO_SUFFIXES``
    """
    # Imported locally so this method does not depend on the module's
    # import header being extended.
    from urllib.parse import parse_qsl, urlsplit

    if not value or not isinstance(value, str):
        return False
    lowered = value.lower()
    if not lowered.startswith("http"):
        return False
    try:
        parts = urlsplit(lowered)
    except ValueError:
        # urlsplit rejects some malformed URLs (e.g. a broken IPv6 host);
        # such a value cannot be downloaded, so it is not audio.
        return False
    if parts.path.endswith(cls._AUDIO_SUFFIXES):
        return True
    # Download endpoints often carry the file name in a query parameter,
    # e.g. ".../download?file=voice.ogg&t=1".
    return any(
        param_value.endswith(cls._AUDIO_SUFFIXES)
        for _, param_value in parse_qsl(parts.query)
    )
def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a Synology Chat audio file and return its raw bytes.

    :param file_ref: reference of the form ``synology://file/<percent-encoded url>``
    :param source: configuration name, looked up through ``get_config``
    :return: response content, or None when the reference is missing or
             malformed, the source has no configuration, or the request
             yields no content
    """
    scheme = "synology://file/"
    if not (file_ref and file_ref.startswith(scheme)):
        return None
    if not self.get_config(source):
        return None
    # Everything after the scheme prefix is the percent-encoded download URL.
    remote_url = unquote(file_ref[len(scheme):])
    reply = RequestUtils(timeout=30).get_res(remote_url)
    if not reply or not reply.content:
        return None
    return reply.content
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息

View File

@@ -21,6 +21,20 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
".tiff",
".svg",
)
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
"""
@@ -118,6 +132,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
content_type = detail.get("content_type") or ""
content = detail.get("content")
images = self._extract_images(detail)
audio_refs = self._extract_audio_refs(detail)
text = None
if content_type in ("text/plain", "text/markdown") and isinstance(content, str):
text = content
@@ -132,14 +147,15 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
userid = f"UID#{msg_body.get('from_uid')}"
# 处理消息内容
if (text or images) and userid:
if (text or images or audio_refs) and userid:
logger.info(
f"收到来自 {client_config.name} 的VoceChat消息"
f"userid={userid}, text={text}, images={len(images) if images else 0}"
f"userid={userid}, text={text}, images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
)
return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name,
userid=userid, username=userid, text=text or "",
images=images)
images=images, audio_refs=audio_refs)
except Exception as err:
logger.error(f"VoceChat消息处理发生错误{str(err)}")
return None
@@ -182,6 +198,37 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
return [f"vocechat://file/{quote(file_path, safe='')}"]
return None
@classmethod
def _extract_audio_refs(cls, detail: dict) -> Optional[List[str]]:
    """
    Build an audio reference from a VoceChat ``vocechat/file`` message.

    The file counts as audio when the MIME type found in ``properties``
    starts with ``audio/`` or the file name ends with one of
    ``_AUDIO_SUFFIXES``. The name falls back to the last segment of the
    file path when ``properties`` does not provide one.

    :param detail: the ``detail`` part of the message
    :return: single-element list ``vocechat://file/<percent-encoded path>``,
             or None when the message is not an audio file or the path is
             not a non-empty string
    """
    if (detail.get("content_type") or "") != "vocechat/file":
        return None

    properties = detail.get("properties") or {}

    def _first_truthy(*keys):
        # Return the first truthy property among ``keys``, else None.
        for key in keys:
            hit = properties.get(key)
            if hit:
                return hit
        return None

    mime = (_first_truthy("content_type", "mime_type", "contentType") or "").lower()
    # The message content is the last-resort location of the file path.
    stored_path = _first_truthy("path", "file_path", "storage_path") or detail.get("content")
    base_name = str(stored_path).rsplit("/", 1)[-1] if stored_path else ""
    name = (_first_truthy("name", "filename") or base_name).lower()

    if not (mime.startswith("audio/") or name.endswith(cls._AUDIO_SUFFIXES)):
        return None
    if not isinstance(stored_path, str) or not stored_path:
        return None
    return [f"vocechat://file/{quote(stored_path, safe='')}"]
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息
@@ -255,3 +302,22 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
return None
file_path = unquote(image_ref.replace("vocechat://file/", "", 1))
return client.download_file_to_data_url(file_path)
def download_vocechat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a VoceChat file and return its raw bytes.

    :param file_ref: reference of the form ``vocechat://file/<percent-encoded path>``
    :param source: configuration name, looked up through ``get_config``
    :return: the first element of the pair returned by the client's
             ``download_file``, or None when the reference is missing or
             malformed, no configuration or client instance exists, or the
             download returns nothing
    """
    scheme = "vocechat://file/"
    if not (file_ref and file_ref.startswith(scheme)):
        return None
    conf = self.get_config(source)
    if not conf:
        return None
    vocechat_client = self.get_instance(conf.name)
    if not vocechat_client:
        return None
    # Everything after the scheme prefix is the percent-encoded file path.
    stored_path = unquote(file_ref[len(scheme):])
    downloaded = vocechat_client.download_file(stored_path)
    if not downloaded:
        return None
    payload, _ = downloaded
    return payload

View File

@@ -3,6 +3,7 @@ import json
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import quote
from telebot import apihelper
@@ -10,6 +11,7 @@ from app.agent.tools.impl.send_message import SendMessageInput
from app.agent import MoviePilotAgent, AgentChain
from app.chain.message import MessageChain
from app.core.config import settings
from app.helper.voice import VoiceHelper
from app.modules.discord import DiscordModule
from app.modules.qqbot import QQBotModule
from app.modules.slack import SlackModule
@@ -203,6 +205,56 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"])
def test_transcribe_audio_refs_supports_new_channel_refs(self):
    """
    Each channel-specific audio reference is routed to its own download
    hook and transcribed under the file name guessed from the reference.
    """
    # One row per channel: (reference, download hook, bytes, transcript, file name)
    cases = [
        (
            "slack://file/" + quote("https://files.slack.com/test.mp3", safe=""),
            "download_slack_file_bytes",
            b"slack",
            "slack text",
            "test.mp3",
        ),
        (
            "discord://file/" + quote("https://cdn.discordapp.com/voice.ogg", safe=""),
            "download_discord_file_bytes",
            b"discord",
            "discord text",
            "voice.ogg",
        ),
        (
            "qq://file/" + quote("https://example.com/qq-voice.ogg", safe=""),
            "download_qq_file_bytes",
            b"qq",
            "qq text",
            "qq-voice.ogg",
        ),
        (
            "vocechat://file/%2Fuploads%2Fvoice.ogg",
            "download_vocechat_file_bytes",
            b"vocechat",
            "vocechat text",
            "voice.ogg",
        ),
        (
            "synology://file/" + quote("https://example.com/synology-voice.wav", safe=""),
            "download_synologychat_file_bytes",
            b"synology",
            "synology text",
            "synology-voice.wav",
        ),
    ]
    chain = MessageChain()

    with patch.object(VoiceHelper, "is_available", return_value=True), patch.object(
        chain,
        "run_module",
        side_effect=[case[2] for case in cases],
    ) as run_module, patch.object(
        VoiceHelper,
        "transcribe_bytes",
        side_effect=[case[3] for case in cases],
    ) as transcribe_bytes:
        result = chain._transcribe_audio_refs(
            audio_refs=[case[0] for case in cases],
            channel=MessageChannel.Slack,
            source="mixed-source",
        )

    self.assertEqual(
        result,
        "slack text\ndiscord text\nqq text\nvocechat text\nsynology text",
    )
    called_hooks = [call.args[0] for call in run_module.call_args_list]
    self.assertEqual(called_hooks, [case[1] for case in cases])
    used_filenames = [call.kwargs["filename"] for call in transcribe_bytes.call_args_list]
    self.assertEqual(used_filenames, [case[4] for case in cases])
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
agent = MoviePilotAgent(
session_id="session-1",
@@ -240,7 +292,7 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(images, ["data:image/png;base64,abc123"])
run_module.assert_called_once_with(
"download_file_to_data_url",
"download_slack_file_to_data_url",
file_url="https://files.slack.com/files-pri/T1-F1/test.png",
source="slack-test",
)
@@ -253,7 +305,7 @@ class AgentImageSupportTest(unittest.TestCase):
with patch.object(
module, "get_config", return_value=SimpleNamespace(name="slack-test")
), patch.object(module, "get_instance", return_value=client):
data_url = module.download_file_to_data_url(
data_url = module.download_slack_file_to_data_url(
"https://files.slack.com/files-pri/T1-F1/test.png",
"slack-test",
)
@@ -263,6 +315,28 @@ class AgentImageSupportTest(unittest.TestCase):
f"data:image/png;base64,{base64.b64encode(b'png-binary').decode()}",
)
def test_slack_extract_audio_refs_returns_private_file_refs(self):
    """
    A Slack audio file with only ``url_private`` yields a
    ``slack://file/`` reference wrapping that URL percent-encoded.
    """
    private_url = "https://files.slack.com/files-pri/T1-F1/test.mp3"
    slack_file = {
        "type": "audio",
        "filetype": "mp3",
        "mimetype": "audio/mpeg",
        "url_private": private_url,
    }

    audio_refs = SlackModule._extract_audio_refs({"files": [slack_file]})

    expected = ["slack://file/" + quote(private_url, safe="")]
    self.assertEqual(audio_refs, expected)
def test_send_message_input_accepts_image_only_payload(self):
payload = SendMessageInput(
explanation="send poster image",
@@ -285,6 +359,27 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(images, ["https://cdn.discordapp.com/test.png"])
def test_discord_extract_audio_refs_supports_attachment_content_type(self):
    """
    A Discord attachment with an ``audio/*`` content type yields a
    ``discord://file/`` reference wrapping its URL percent-encoded.
    """
    voice_url = "https://cdn.discordapp.com/voice.ogg"
    attachment = {
        "content_type": "audio/ogg",
        "filename": "voice.ogg",
        "url": voice_url,
    }

    audio_refs = DiscordModule._extract_audio_refs({"attachments": [attachment]})

    expected = ["discord://file/" + quote(voice_url, safe="")]
    self.assertEqual(audio_refs, expected)
def test_discord_send_direct_message_returns_chat_id(self):
module = DiscordModule()
client = Mock()
@@ -464,6 +559,41 @@ class AgentImageSupportTest(unittest.TestCase):
["vocechat://file/%2Fuploads%2Fposter.png"],
)
def test_vocechat_message_parser_extracts_audio_file_payload(self):
    """
    A VoceChat ``vocechat/file`` message whose properties declare an
    audio MIME type is parsed into a message carrying one audio reference.
    """
    detail = {
        "type": "normal",
        "content_type": "vocechat/file",
        "content": "/uploads/voice.ogg",
        "properties": {"content_type": "audio/ogg"},
    }
    body = json.dumps({"detail": detail, "from_uid": 7910, "target": {"gid": 2}})
    fake_config = SimpleNamespace(name="vocechat-test", config={"channel_id": "2"})
    module = VoceChatModule()

    with patch.object(module, "get_config", return_value=fake_config):
        message = module.message_parser(
            source="vocechat-test",
            body=body,
            form={},
            args={},
        )

    self.assertIsNotNone(message)
    expected = ["vocechat://file/%2Fuploads%2Fvoice.ogg"]
    self.assertEqual(message.audio_refs, expected)
def test_vocechat_post_message_passes_image_and_correct_target(self):
module = VoceChatModule()
client = Mock()
@@ -520,6 +650,37 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertIsNotNone(message)
self.assertEqual(message.images, ["https://example.com/qq-image.png"])
def test_qq_message_parser_accepts_audio_only_attachment(self):
    """
    A QQ private message with no text and a single audio attachment is
    still parsed, and carries one ``qq://file/`` audio reference.
    """
    voice_url = "https://example.com/qq-voice.ogg"
    body = {
        "type": "C2C_MESSAGE_CREATE",
        "author": {"user_openid": "qq-user"},
        "attachments": [
            {
                "content_type": "audio/ogg",
                "filename": "voice.ogg",
                "url": voice_url,
            }
        ],
    }
    fake_config = SimpleNamespace(name="qq-test", config={})
    module = QQBotModule()

    with patch.object(module, "get_config", return_value=fake_config):
        message = module.message_parser(
            source="qq-test",
            body=body,
            form={},
            args={},
        )

    self.assertIsNotNone(message)
    expected = ["qq://file/" + quote(voice_url, safe="")]
    self.assertEqual(message.audio_refs, expected)
def test_synology_message_parser_accepts_image_only_form(self):
module = SynologyChatModule()
@@ -547,5 +708,42 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertIsNotNone(message)
self.assertEqual(message.images, ["https://example.com/image.png"])
def test_synology_message_parser_accepts_audio_only_form(self):
    """
    A Synology Chat form with no text and a JSON-encoded audio attachment
    is still parsed, and carries one ``synology://file/`` audio reference.
    """
    voice_url = "https://example.com/voice.ogg"
    form = {
        "token": "token-1",
        "user_id": "42",
        "username": "tester",
        "attachments": json.dumps(
            [{"url": voice_url, "content_type": "audio/ogg"}]
        ),
    }
    fake_config = SimpleNamespace(name="synology-test", config={})
    # Stand-in client that accepts only the token sent in the form.
    fake_client = SimpleNamespace(check_token=lambda token: token == "token-1")
    module = SynologyChatModule()

    with patch.object(module, "get_config", return_value=fake_config), patch.object(
        module, "get_instance", return_value=fake_client
    ):
        message = module.message_parser(
            source="synology-test",
            body={},
            form=form,
            args={},
        )

    self.assertIsNotNone(message)
    expected = ["synology://file/" + quote(voice_url, safe="")]
    self.assertEqual(message.audio_refs, expected)
if __name__ == "__main__":
unittest.main()