diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index c2686f89..18595596 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -67,16 +67,40 @@ async def user_message(background_tasks: BackgroundTasks, request: Request, @router.post("/web", summary="接收WEB消息", response_model=schemas.Response) -def web_message(text: str, current_user: User = Depends(get_current_active_superuser)): +async def web_message( + request: Request, + text: Optional[str] = None, + current_user: User = Depends(get_current_active_superuser), +): """ WEB消息响应 """ + images = None + content_type = request.headers.get("content-type", "") + if "application/json" in content_type: + try: + payload = await request.json() + except Exception: + payload = None + if isinstance(payload, dict): + text = payload.get("text", text) + image = payload.get("image") + images = payload.get("images") + if image: + if isinstance(images, list): + images = [*images, image] + else: + images = [image] + elif isinstance(images, str): + images = [images] + MessageChain().handle_message( channel=MessageChannel.Web, source=current_user.name, userid=current_user.name, username=current_user.name, - text=text + text=text or "", + images=images, ) return schemas.Response(success=True) diff --git a/app/chain/message.py b/app/chain/message.py index 1817c8ec..cf2e63fb 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1371,7 +1371,15 @@ class MessageChain(ChainBase): base64_images = [] for img in images: try: - if img.startswith("tg://file_id/"): + if img.startswith("data:"): + base64_images.append(img) + logger.info( + "图片无需下载: channel=%s, source=%s, input=%s", + channel.value if channel else None, + source, + self._summarize_images([img])[0], + ) + elif img.startswith("tg://file_id/"): file_id = img.replace("tg://file_id/", "") base64_data = self.run_module( "download_file_to_base64", file_id=file_id, source=source @@ -1384,6 +1392,23 @@ class MessageChain(ChainBase): source, img, ) + elif img.startswith("wxwork://media_id/") or img.startswith( + "wxbot://image/" + ): + data_url = self.run_module( + "download_wechat_image_to_data_url", + image_ref=img, + source=source, + ) + if data_url: + base64_images.append(data_url) + logger.info( + "图片下载成功: channel=%s, source=%s, input=%s, output=%s", + channel.value if channel else None, + source, + img, + self._summarize_images([data_url])[0], + ) elif channel == MessageChannel.Slack: data_url = self.run_module( "download_file_to_data_url", file_url=img, source=source @@ -1397,6 +1422,21 @@ class MessageChain(ChainBase): img, self._summarize_images([data_url])[0], ) + elif img.startswith("vocechat://file/"): + data_url = self.run_module( + "download_vocechat_image_to_data_url", + image_ref=img, + source=source, + ) + if data_url: + base64_images.append(data_url) + logger.info( + "图片下载成功: channel=%s, source=%s, input=%s, output=%s", + channel.value if channel else None, + source, + img, + self._summarize_images([data_url])[0], + ) elif img.startswith("http"): resp = RequestUtils(timeout=30).get_res(img) if resp and resp.content: diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index c1d9ce1d..43f1b99e 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -15,6 +15,17 @@ except Exception as err: # ImportError or other load issues class DiscordModule(_ModuleBase, _MessageBase[Discord]): + _IMAGE_SUFFIXES = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tiff", + ".svg", + ) + def init_module(self) -> None: """ 初始化模块 @@ -157,10 +168,17 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): return None images = [] for attachment in attachments: - if attachment.get("type") == "image": - url = attachment.get("url") - if url: - images.append(url) + url = attachment.get("url") or attachment.get("proxy_url") + if not url: + continue + content_type = (attachment.get("content_type") or "").lower() + filename = (attachment.get("filename") or "").lower() + if ( + attachment.get("type") == "image" + or content_type.startswith("image/") + or filename.endswith(DiscordModule._IMAGE_SUFFIXES) + ): + images.append(url) return images if images else None def post_message(self, message: Notification, **kwargs) -> None: @@ -363,15 +381,22 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): userid=userid, ) if result: - success, message_id = ( + success, response_data = ( (result[0], result[1]) if isinstance(result, tuple) else (result, None) ) if success: + message_id = None + chat_id = None + if isinstance(response_data, dict): + message_id = response_data.get("message_id") + chat_id = response_data.get("chat_id") + elif response_data is not None: + message_id = str(response_data) return MessageResponse( message_id=str(message_id) if message_id else None, - chat_id=None, + chat_id=str(chat_id) if chat_id else None, channel=MessageChannel.Discord, source=conf.name, success=True, diff --git a/app/modules/discord/discord.py b/app/modules/discord/discord.py index 49e56c2c..e2ddb858 100644 --- a/app/modules/discord/discord.py +++ b/app/modules/discord/discord.py @@ -126,6 +126,20 @@ class Discord: if isinstance(message.channel, discord.DMChannel) else "guild", } + if message.attachments: + payload["attachments"] = [ + { + "id": str(attachment.id), + "filename": attachment.filename, + "content_type": attachment.content_type, + "url": attachment.url, + "proxy_url": attachment.proxy_url, + "size": attachment.size, + "height": attachment.height, + "width": attachment.width, + } + for attachment in message.attachments + ] await self._post_to_ds(payload) @self._client.event @@ -346,7 +360,7 @@ class Discord: original_message_id: Optional[Union[int, str]], original_chat_id: Optional[str], mtype: Optional["NotificationType"] = None, - ) -> Tuple[bool, Optional[int]]: + ) -> Tuple[bool, Optional[Dict[str, str]]]: logger.debug( f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}" ) @@ -373,13 +387,29 @@ class Discord: embed=embed, view=view, ) - return success, int(original_message_id) if original_message_id else None + return ( + success, + { + "message_id": str(original_message_id), + "chat_id": str(original_chat_id), + } + if success and original_message_id and original_chat_id + else None, + ) logger.debug(f"[Discord] 发送新消息到频道: {channel}") try: sent_message = await channel.send(content=content, embed=embed, view=view) logger.debug("[Discord] 消息发送成功") - return True, sent_message.id if sent_message else None + return ( + True, + { + "message_id": str(sent_message.id), + "chat_id": str(channel.id), + } + if sent_message and getattr(channel, "id", None) is not None + else None, + ) except Exception as e: logger.error(f"[Discord] 发送消息到频道失败: {e}") return False, None diff --git a/app/modules/qqbot/__init__.py b/app/modules/qqbot/__init__.py index 384ecae9..e284bfba 100644 --- a/app/modules/qqbot/__init__.py +++ b/app/modules/qqbot/__init__.py @@ -18,6 +18,17 @@ from app.schemas.types import ModuleType class QQBotModule(_ModuleBase, _MessageBase[QQBot]): """QQ Bot 通知模块""" + _IMAGE_SUFFIXES = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tiff", + ".svg", + ) + def init_module(self) -> None: self.stop() super().init_service(service_name=QQBot.__name__.lower(), service_type=QQBot) @@ -78,7 +89,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): msg_type = msg_body.get("type") content = (msg_body.get("content") or "").strip() - if not content: + images = self._extract_images(msg_body) + if not content and not images: return None if msg_type == "C2C_MESSAGE_CREATE": @@ -86,13 +98,17 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): user_openid = author.get("user_openid", "") if not user_openid: return None - logger.info(f"收到 QQ 私聊消息: userid={user_openid}, text={content[:50]}...") + logger.info( + f"收到 QQ 私聊消息: userid={user_openid}, " + f"text={(content or '')[:50]}..., images={len(images) if images else 0}" + ) return CommingMessage( channel=MessageChannel.QQ, source=client_config.name, userid=user_openid, username=user_openid, text=content, + images=images, ) elif msg_type == "GROUP_AT_MESSAGE_CREATE": author = msg_body.get("author", {}) @@ -100,16 +116,65 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): group_openid = msg_body.get("group_openid", "") # 群聊用 group:group_openid 作为 userid,便于回复时识别 userid = f"group:{group_openid}" if group_openid else member_openid - logger.info(f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, text={content[:50]}...") + logger.info( + f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, " + f"text={(content or '')[:50]}..., images={len(images) if images else 0}" + ) return CommingMessage( channel=MessageChannel.QQ, source=client_config.name, userid=userid, username=member_openid or group_openid, text=content, + images=images, ) return None + @classmethod + def _extract_images(cls, msg_body: dict) -> Optional[List[str]]: + images: List[str] = [] + attachments = msg_body.get("attachments") or [] + if isinstance(attachments, list): + for attachment in attachments: + if not isinstance(attachment, dict): + continue + url = attachment.get("url") or attachment.get("proxy_url") + if not url: + continue + content_type = ( + attachment.get("content_type") + or attachment.get("mime_type") + or "" + ).lower() + filename = ( + attachment.get("filename") + or attachment.get("name") + or "" + ).lower() + if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES): + images.append(url) + + for key in ("image", "image_url", "pic_url"): + value = msg_body.get(key) + if isinstance(value, str) and value.startswith("http"): + images.append(value) + + extra_images = msg_body.get("images") + if isinstance(extra_images, list): + for item in extra_images: + if isinstance(item, str) and item.startswith("http"): + images.append(item) + elif isinstance(item, dict): + url = item.get("url") or item.get("image_url") + if isinstance(url, str) and url.startswith("http"): + images.append(url) + + deduped = [] + for image in images: + if image not in deduped: + deduped.append(image) + return deduped or None + def post_message(self, message: Notification, **kwargs) -> None: for conf in self.get_configs().values(): if not self.check_message(message, conf.name): diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 78c004ab..12fca4de 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,4 +1,5 @@ from typing import Optional, Union, List, Tuple, Any +import json from app.core.context import MediaInfo, Context from app.log import logger @@ -9,6 +10,16 @@ from app.schemas.types import ModuleType class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): + _IMAGE_SUFFIXES = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tiff", + ".svg", + ) def init_module(self) -> None: """ @@ -96,15 +107,59 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): user_id = int(message.get("user_id")) # 获取用户名 user_name = message.get("username") - if text and user_id: - logger.info(f"收到来自 {client_config.name} 的SynologyChat消息:" - f"userid={user_id}, username={user_name}, text={text}") + images = self._extract_images(message) + if (text or images) and user_id: + logger.info( + f"收到来自 {client_config.name} 的SynologyChat消息:" + f"userid={user_id}, username={user_name}, text={text}, images={len(images) if images else 0}" + ) return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name, - userid=user_id, username=user_name, text=text) + userid=user_id, username=user_name, text=text or "", + images=images) except Exception as err: logger.debug(f"解析SynologyChat消息失败:{str(err)}") return None + @classmethod + def _extract_images(cls, message: dict) -> Optional[List[str]]: + images = [] + for key in ("file_url", "image_url", "pic_url"): + value = message.get(key) + if isinstance(value, str) and cls._looks_like_image(value): + images.append(value) + + for key in ("attachments", "files"): + raw_value = message.get(key) + if not raw_value: + continue + try: + parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value + except Exception: + parsed = raw_value + items = parsed if isinstance(parsed, list) else [parsed] + for item in items: + if isinstance(item, str) and cls._looks_like_image(item): + images.append(item) + elif isinstance(item, dict): + url = item.get("url") or item.get("file_url") or item.get("image_url") + if isinstance(url, str) and cls._looks_like_image(url): + images.append(url) + + deduped = [] + for image in images: + if image not in deduped: + deduped.append(image) + return deduped or None + + @classmethod + def _looks_like_image(cls, value: str) -> bool: + if not value or not isinstance(value, str): + return False + lowered = value.lower() + return lowered.startswith("http") and any( + suffix in lowered for suffix in cls._IMAGE_SUFFIXES + ) + def post_message(self, message: Notification, **kwargs) -> None: """ 发送消息 diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index 85382f3a..3ac04d71 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -1,4 +1,5 @@ import json +from urllib.parse import quote, unquote from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import Context, MediaInfo @@ -10,6 +11,16 @@ from app.schemas.types import ModuleType class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): + _IMAGE_SUFFIXES = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tiff", + ".svg", + ) def init_module(self) -> None: """ @@ -99,12 +110,17 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): msg_body = json.loads(body) # 类型 msg_type = msg_body.get("detail", {}).get("type") - if msg_type != "normal": - # 非新消息 + if msg_type not in ("normal", "reply"): + # 非新消息/回复 return None logger.debug(f"收到VoceChat请求:{msg_body}") - # 文本内容 - content = msg_body.get("detail", {}).get("content") + detail = msg_body.get("detail", {}) or {} + content_type = detail.get("content_type") or "" + content = detail.get("content") + images = self._extract_images(detail) + text = None + if content_type in ("text/plain", "text/markdown") and isinstance(content, str): + text = content # 用户ID gid = msg_body.get("target", {}).get("gid") channel_id = client_config.config.get("channel_id") @@ -116,14 +132,56 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): userid = f"UID#{msg_body.get('from_uid')}" # 处理消息内容 - if content and userid: - logger.info(f"收到来自 {client_config.name} 的VoceChat消息:userid={userid}, text={content}") + if (text or images) and userid: + logger.info( + f"收到来自 {client_config.name} 的VoceChat消息:" + f"userid={userid}, text={text}, images={len(images) if images else 0}" + ) return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name, - userid=userid, username=userid, text=content) + userid=userid, username=userid, text=text or "", + images=images) except Exception as err: logger.error(f"VoceChat消息处理发生错误:{str(err)}") return None + @classmethod + def _extract_images(cls, detail: dict) -> Optional[List[str]]: + content_type = detail.get("content_type") or "" + if content_type != "vocechat/file": + return None + properties = detail.get("properties") or {} + mime_type = ( + properties.get("content_type") + or properties.get("mime_type") + or properties.get("contentType") + or "" + ).lower() + file_path = ( + properties.get("path") + or properties.get("file_path") + or properties.get("storage_path") + or detail.get("content") + ) + direct_url = ( + properties.get("url") + or properties.get("download_url") + or properties.get("file_url") + ) + file_name = ( + properties.get("name") + or properties.get("filename") + or (str(file_path).rsplit("/", 1)[-1] if file_path else "") + ).lower() + + is_image = mime_type.startswith("image/") or file_name.endswith(cls._IMAGE_SUFFIXES) + if not is_image: + return None + if isinstance(direct_url, str) and direct_url.startswith("http"): + return [direct_url] + if isinstance(file_path, str) and file_path: + return [f"vocechat://file/{quote(file_path, safe='')}"] + return None + def post_message(self, message: Notification, **kwargs) -> None: """ 发送消息 @@ -136,11 +194,11 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): targets = message.targets userid = message.userid if not message.userid and targets: - userid = targets.get('telegram_userid') + userid = targets.get('vocechat_userid') client: VoceChat = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, - userid=userid, link=message.link) + image=message.image, userid=userid, link=message.link) def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ @@ -182,3 +240,18 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): def register_commands(self, commands: Dict[str, dict]): pass + + def download_vocechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]: + """ + 下载 VoceChat 图片并转换为 data URL + """ + if not image_ref or not image_ref.startswith("vocechat://file/"): + return None + client_config = self.get_config(source) + if not client_config: + return None + client: VoceChat = self.get_instance(client_config.name) + if not client: + return None + file_path = unquote(image_ref.replace("vocechat://file/", "", 1)) + return client.download_file_to_data_url(file_path) diff --git a/app/modules/vocechat/vocechat.py b/app/modules/vocechat/vocechat.py index 10594592..d5834dec 100644 --- a/app/modules/vocechat/vocechat.py +++ b/app/modules/vocechat/vocechat.py @@ -1,6 +1,8 @@ import re import threading -from typing import Optional, List +import base64 +from typing import Optional, List, Tuple +from urllib.parse import quote from app.core.context import MediaInfo, Context from app.core.metainfo import MetaInfo @@ -21,6 +23,7 @@ class VoceChat: _channel_id = None # 请求对象 _client = None + _file_client = None def __init__(self, VOCECHAT_HOST: Optional[str] = None, VOCECHAT_API_KEY: Optional[str] = None, VOCECHAT_CHANNEL_ID: Optional[str] = None, **kwargs): """ @@ -29,12 +32,11 @@ class VoceChat: if not VOCECHAT_HOST or not VOCECHAT_API_KEY or not VOCECHAT_CHANNEL_ID: logger.error("VoceChat配置不完整!") return - self._host = VOCECHAT_HOST - if self._host: - if not self._host.endswith("/"): - self._host += "/" - if not self._host.startswith("http"): - self._playhost = "http://" + self._host + self._host = VOCECHAT_HOST.strip() + if self._host and not self._host.startswith("http"): + self._host = f"http://{self._host}" + if self._host and not self._host.endswith("/"): + self._host += "/" self._apikey = VOCECHAT_API_KEY self._channel_id = VOCECHAT_CHANNEL_ID if self._apikey and self._host and self._channel_id: @@ -43,6 +45,10 @@ class VoceChat: "x-api-key": self._apikey, "accept": "application/json; charset=utf-8" }) + self._file_client = RequestUtils(headers={ + "x-api-key": self._apikey, + "accept": "*/*" + }) def get_state(self): """ @@ -61,6 +67,7 @@ class VoceChat: return result.json() def send_msg(self, title: str, text: Optional[str] = None, + image: Optional[str] = None, userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]: """ 微信消息发送入口,支持文本、图片、链接跳转、指定发送对象 @@ -83,6 +90,9 @@ class VoceChat: else: caption = f"**{title}**" + if image: + caption = f"{caption}\n![]({image})" + if link: caption = f"{caption}\n[查看详情]({link})" @@ -97,6 +107,46 @@ class VoceChat: logger.error(f"发送消息失败:{msg_e}") return False + @staticmethod + def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str: + if not content: + return default + if content.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if content.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if content.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if content.startswith(b"BM"): + return "image/bmp" + if content.startswith(b"RIFF") and b"WEBP" in content[:16]: + return "image/webp" + return default + + def download_file(self, path: str) -> Optional[Tuple[bytes, str]]: + """ + 下载 VoceChat 文件资源 + """ + if not path or not self._file_client: + return None + req_url = f"{self._host}api/resource/file?path={quote(path, safe='')}" + try: + res = self._file_client.get_res(req_url) + except Exception as err: + logger.error(f"VoceChat 文件下载失败:{err}") + return None + if not res or not res.content: + return None + mime_type = (res.headers.get("Content-Type") or "").split(";")[0].strip() + return res.content, mime_type or self._guess_mime_type(res.content) + + def download_file_to_data_url(self, path: str) -> Optional[str]: + file_data = self.download_file(path) + if not file_data: + return None + content, mime_type = file_data + return f"data:{mime_type};base64,{base64.b64encode(content).decode()}" + def send_medias_msg(self, title: str, medias: List[MediaInfo], userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]: """ diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 81a54fd2..04d7c2cc 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -1,4 +1,6 @@ import copy +import json +import re import xml.dom.minidom from typing import Optional, Union, List, Tuple, Any, Dict @@ -103,7 +105,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): if not client_config: return None if self._is_bot_mode(client_config.config): - return None + return self._parse_bot_message(source=source, body=body, client_config=client_config) client: WeChat = self.get_instance(client_config.name) # URL参数 sVerifyMsgSig = args.get("msg_signature") @@ -163,6 +165,8 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): logger.warn(f"解析不到消息类型和用户ID") return None # 解析消息内容 + content = None + images = None if msg_type == "event" and event == "click": # 校验用户有权限执行交互命令 if client_config.config.get('WECHAT_ADMINS'): @@ -178,17 +182,85 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): # 文本消息 content = DomUtils.tag_value(root_node, "Content", default="") logger.info(f"收到来自 {client_config.name} 的微信消息:userid={user_id}, text={content}") + elif msg_type == "image": + media_id = DomUtils.tag_value(root_node, "MediaId") + pic_url = DomUtils.tag_value(root_node, "PicUrl") + if media_id: + images = [f"wxwork://media_id/{media_id}"] + elif pic_url: + images = [pic_url] + logger.info( + f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}" + ) else: return None - if content: + if content or images: # 处理消息内容 return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name, - userid=user_id, username=user_id, text=content) + userid=user_id, username=user_id, text=content or "", + images=images) except Exception as err: logger.error(f"微信消息处理发生错误:{str(err)}") return None + def _parse_bot_message(self, source: str, body: Any, client_config) -> Optional[CommingMessage]: + try: + if isinstance(body, bytes): + msg_json = json.loads(body) + elif isinstance(body, dict): + msg_json = body + else: + msg_json = json.loads(body) + while isinstance(msg_json, str): + msg_json = json.loads(msg_json) + except Exception as err: + logger.debug(f"解析企业微信智能机器人消息失败:{err}") + return None + + if not isinstance(msg_json, dict): + return None + + payload_body = msg_json.get("body") or {} + sender = ((payload_body.get("from") or {}).get("userid") or "").strip() + if not sender: + return None + if payload_body.get("chattype") == "group": + return None + + text = WeChatBot._extract_text_from_body(payload_body) + images = WeChatBot._extract_images_from_body(payload_body) + if text: + text = re.sub(r"@\S+", "", text).strip() + + if text and text.startswith("/") and client_config.config.get('WECHAT_ADMINS'): + wechat_admins = [ + admin.strip() + for admin in client_config.config.get('WECHAT_ADMINS', '').split(',') + if admin.strip() + ] + if wechat_admins and sender not in wechat_admins: + client: WeChatBot = self.get_instance(client_config.name) + if client: + client.send_msg(title="只有管理员才有权限执行此命令", userid=sender) + return None + + if not text and not images: + return None + + logger.info( + f"收到来自 {client_config.name} 的企业微信智能机器人消息:" + f"userid={sender}, text={text}, images={len(images) if images else 0}" + ) + return CommingMessage( + channel=MessageChannel.Wechat, + source=client_config.name, + userid=sender, + username=sender, + text=text or "", + images=images, + ) + def post_message(self, message: Notification, **kwargs) -> None: """ 发送消息 @@ -210,6 +282,25 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): client.send_msg(title=message.title, text=message.text, image=message.image, userid=userid, link=message.link) + def download_wechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]: + """ + 下载企业微信渠道图片并转换为 data URL + """ + if not image_ref: + return None + client_config = self.get_config(source) + if not client_config: + return None + client = self.get_instance(client_config.name) + if not client: + return None + if image_ref.startswith("wxwork://media_id/") and hasattr(client, "download_media_to_data_url"): + media_id = image_ref.replace("wxwork://media_id/", "", 1) + return client.download_media_to_data_url(media_id) + if image_ref.startswith("wxbot://image/") and hasattr(client, "download_image_to_data_url"): + return client.download_image_to_data_url(image_ref) + return None + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 diff --git a/app/modules/wechat/wechat.py b/app/modules/wechat/wechat.py index 93c286a6..c42e7858 100644 --- a/app/modules/wechat/wechat.py +++ b/app/modules/wechat/wechat.py @@ -1,6 +1,7 @@ import json import re import threading +import base64 from datetime import datetime from typing import Optional, List, Dict @@ -43,6 +44,8 @@ class WeChat: _create_menu_url = "cgi-bin/menu/create?access_token={access_token}&agentid={agentid}" # 企业微信删除菜单URL _delete_menu_url = "cgi-bin/menu/delete?access_token={access_token}&agentid={agentid}" + # 企业微信下载媒体URL + _download_media_url = "cgi-bin/media/get?access_token={access_token}&media_id={media_id}" def __init__(self, WECHAT_CORPID: Optional[str] = None, WECHAT_APP_SECRET: Optional[str] = None, WECHAT_APP_ID: Optional[str] = None, WECHAT_PROXY: Optional[str] = None, **kwargs): @@ -62,6 +65,7 @@ class WeChat: self._token_url = UrlUtils.adapt_request_url(self._proxy, self._token_url) self._create_menu_url = UrlUtils.adapt_request_url(self._proxy, self._create_menu_url) self._delete_menu_url = UrlUtils.adapt_request_url(self._proxy, self._delete_menu_url) + self._download_media_url = UrlUtils.adapt_request_url(self._proxy, self._download_media_url) if self._corpid and self._appsecret and self._appid: self.__get_access_token() @@ -267,6 +271,58 @@ class WeChat: logger.error(f"发送消息失败:{e}") return False + @staticmethod + def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str: + """ + 根据文件头推断图片 MIME + """ + if not content: + return default + if content.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if content.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if content.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if content.startswith(b"BM"): + return "image/bmp" + if content.startswith(b"RIFF") and b"WEBP" in content[:16]: + return "image/webp" + return default + + def download_media_to_data_url(self, media_id: str) -> Optional[str]: + """ + 下载企业微信媒体文件并转换为 data URL + """ + if not media_id: + return None + access_token = self.__get_access_token() + if not access_token: + logger.error("下载企业微信媒体失败:access_token 获取失败") + return None + req_url = self._download_media_url.format( + access_token=access_token, + media_id=media_id, + ) + try: + res = RequestUtils(timeout=30).get_res(req_url) + except Exception as err: + logger.error(f"下载企业微信媒体失败:{err}") + return None + if not res or not res.content: + return None + + content_type = (res.headers.get("Content-Type") or "").split(";")[0].strip() + if content_type == "application/json": + try: + logger.error(f"企业微信媒体下载失败:{res.json()}") + except Exception: + logger.error(f"企业微信媒体下载失败:{res.text}") + return None + + mime_type = self._guess_mime_type(res.content, content_type or "image/jpeg") + return f"data:{mime_type};base64,{base64.b64encode(res.content).decode()}" + def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None) -> Optional[bool]: """ 发送列表类消息 diff --git a/app/modules/wechat/wechatbot.py b/app/modules/wechat/wechatbot.py index f840d3f0..da349650 100644 --- a/app/modules/wechat/wechatbot.py +++ b/app/modules/wechat/wechatbot.py @@ -5,9 +5,11 @@ import re import threading import time import uuid +import base64 from typing import Optional, List, Dict, Tuple, Set import websocket +from Crypto.Cipher import AES from app.core.cache import FileCache from app.core.config import settings @@ -332,6 +334,116 @@ class WeChatBot: text = "\n".join(part for part in text_parts if part).strip() return text or None + @staticmethod + def _build_image_ref(image_payload: dict) -> Optional[str]: + if not image_payload or not isinstance(image_payload, dict): + return None + download_url = ( + image_payload.get("download_url") + or image_payload.get("url") + or image_payload.get("cdnurl") + ) + if not download_url: + return None + payload = { + "url": download_url, + "aeskey": image_payload.get("aeskey") + or image_payload.get("encoding_aes_key") + or image_payload.get("encrypt_key"), + "mime_type": image_payload.get("mime_type") + or image_payload.get("content_type"), + } + encoded = base64.urlsafe_b64encode( + json.dumps(payload, ensure_ascii=False).encode("utf-8") + ).decode("ascii").rstrip("=") + return f"wxbot://image/{encoded}" + + @classmethod + def _extract_images_from_body(cls, body: dict) -> Optional[List[str]]: + images: List[str] = [] + msgtype = body.get("msgtype") + + if msgtype == "image": + image_ref = cls._build_image_ref(body.get("image") or {}) + if image_ref: + images.append(image_ref) + elif msgtype == "mixed": + for item in (body.get("mixed") or {}).get("msg_item") or []: + if item.get("msgtype") != "image": + continue + image_ref = cls._build_image_ref(item.get("image") or {}) + if image_ref: + images.append(image_ref) + + quote = body.get("quote") or {} + if not images and quote.get("msgtype") == "image": + image_ref = cls._build_image_ref(quote.get("image") or {}) + if image_ref: + images.append(image_ref) + + return images or None + + @staticmethod + def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str: + if not content: + return default + if content.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if content.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if content.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if content.startswith(b"BM"): + return "image/bmp" + if content.startswith(b"RIFF") and b"WEBP" in content[:16]: + return "image/webp" + return default + + def download_image_to_data_url(self, image_ref: str) -> Optional[str]: + if not image_ref or not image_ref.startswith("wxbot://image/"): + return None + encoded = image_ref.replace("wxbot://image/", "", 1) + try: + padding = "=" * (-len(encoded) % 4) + payload = json.loads( + base64.urlsafe_b64decode((encoded + padding).encode("ascii")).decode( + "utf-8" + ) + ) + except Exception as err: + logger.error(f"解析企业微信智能机器人图片引用失败:{err}") + return None + + download_url = payload.get("url") + if not download_url: + return None + + try: + resp = RequestUtils(timeout=30).get_res(download_url) + except Exception as err: + logger.error(f"下载企业微信智能机器人图片失败:{err}") + return None + if not resp or not resp.content: + return None + + content = resp.content + aes_key = payload.get("aeskey") + if aes_key: + try: + aes_bytes = base64.b64decode(aes_key + "=" * (-len(aes_key) % 4)) + cipher = AES.new(aes_bytes, AES.MODE_CBC, aes_bytes[:16]) + decrypted = cipher.decrypt(content) + padding_len = decrypted[-1] + if 0 < padding_len <= 32: + decrypted = decrypted[:-padding_len] + content = decrypted + except Exception as err: + logger.error(f"解密企业微信智能机器人图片失败:{err}") + return None + + mime_type = self._guess_mime_type(content, payload.get("mime_type") or "image/jpeg") + return f"data:{mime_type};base64,{base64.b64encode(content).decode()}" + def _handle_callback_message(self, payload: dict) -> None: body = payload.get("body") or {} sender = ((body.get("from") or {}).get("userid") or "").strip() @@ -343,20 +455,24 @@ class WeChatBot: return text = self._extract_text_from_body(body) - if not text: - return + images = self._extract_images_from_body(body) - text = re.sub(r"@\S+", "", text).strip() - if not text: + if text: + text = re.sub(r"@\S+", "", text).strip() + + if not text and not images: return self._remember_target(sender) - if text.startswith("/") and self._admins and sender not in self._admins: + if text and text.startswith("/") and self._admins and sender not in self._admins: self.send_msg(title="只有管理员才有权限执行此命令", userid=sender) return - logger.info(f"收到来自 {self._config_name} 的企业微信智能机器人消息:userid={sender}, text={text}") + logger.info( + f"收到来自 {self._config_name} 的企业微信智能机器人消息:" + f"userid={sender}, text={text}, images={len(images) if images else 0}" + ) self._forward_to_message_chain(payload) def _forward_to_message_chain(self, payload: dict) -> None: diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index 389138c2..cd875786 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -9,10 +9,16 @@ from telebot import apihelper from app.agent.tools.impl.send_message import SendMessageInput from app.chain.message import MessageChain from app.core.config import settings +from app.modules.discord import DiscordModule +from app.modules.qqbot import QQBotModule from app.modules.slack import SlackModule from app.modules.telegram.telegram import Telegram from app.modules.telegram import TelegramModule -from app.schemas import CommingMessage +from app.modules.synologychat import SynologyChatModule +from app.modules.vocechat import VoceChatModule +from app.modules.wechat import WechatModule +from app.modules.wechat.wechatbot import WeChatBot +from app.schemas import CommingMessage, Notification from app.schemas.types import MessageChannel @@ -190,5 +196,281 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(payload.image_url, "https://example.com/poster.png") + def test_discord_extract_images_supports_attachment_content_type(self): + images = DiscordModule._extract_images( + { + "attachments": [ + { + "content_type": "image/png", + "url": "https://cdn.discordapp.com/test.png", + } + ] + } + ) + + self.assertEqual(images, ["https://cdn.discordapp.com/test.png"]) + + def test_discord_send_direct_message_returns_chat_id(self): + module = DiscordModule() + client = Mock() + client.send_msg.return_value = ( + True, + {"message_id": "discord-msg-1", "chat_id": "discord-chat-1"}, + ) + + with patch.object( + module, + "get_configs", + return_value={"discord-test": SimpleNamespace(name="discord-test")}, + ), patch.object( + module, "check_message", return_value=True + ), patch.object( + module, "get_instance", return_value=client + ): + response = module.send_direct_message( + Notification(title="hi", userid="user-1") + ) + + self.assertIsNotNone(response) + self.assertEqual(response.message_id, "discord-msg-1") + self.assertEqual(response.chat_id, "discord-chat-1") + + def test_download_images_routes_wechat_refs_to_module_downloader(self): + chain = MessageChain() + + with patch.object( + chain, + "run_module", + return_value="data:image/png;base64,wechat123", + ) as run_module: + images = chain._download_images_to_base64( + images=["wxwork://media_id/media-1"], + channel=MessageChannel.Wechat, + source="wechat-test", + ) + + self.assertEqual(images, ["data:image/png;base64,wechat123"]) + run_module.assert_called_once_with( + "download_wechat_image_to_data_url", + image_ref="wxwork://media_id/media-1", + source="wechat-test", + ) + + def test_wechat_message_parser_extracts_image_media_id(self): + module = WechatModule() + xml_message = b""" + + + + + + + """ + crypt = Mock() + crypt.DecryptMsg.return_value = (0, xml_message) + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace( + name="wechat-test", + config={ + "WECHAT_TOKEN": "token", + "WECHAT_ENCODING_AESKEY": "encoding", + "WECHAT_CORPID": "corpid", + }, + ), + ), patch.object( + module, "get_instance", return_value=SimpleNamespace(send_msg=Mock()) + ), patch( + "app.modules.wechat.WXBizMsgCrypt", + return_value=crypt, + ): + message = module.message_parser( + source="wechat-test", + body=b"encrypted", + form={}, + args={"msg_signature": "sig", "timestamp": "1", "nonce": "n"}, + ) + + self.assertIsNotNone(message) + self.assertEqual(message.images, ["wxwork://media_id/media-1"]) + + def test_wechat_bot_parser_accepts_image_only_payload(self): + module = WechatModule() + body = json.dumps( + { + "body": { + "from": {"userid": "wxbot-user"}, + "msgtype": "image", + "image": { + "download_url": "https://example.com/encrypted-image", + "aeskey": "YWJjZGVmZw", + }, + } + } + ) + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace( + name="wechat-bot-test", config={"WECHAT_MODE": "bot"} + ), + ), patch.object( + module, "get_instance", return_value=SimpleNamespace(send_msg=Mock()) + ): + message = module.message_parser( + source="wechat-bot-test", + body=body, + form={}, + args={}, + ) + + self.assertIsNotNone(message) + self.assertTrue(message.images[0].startswith("wxbot://image/")) + + def test_wechat_bot_handles_image_only_callback(self): + bot = WeChatBot.__new__(WeChatBot) + bot._config_name = "wechat-bot-test" + bot._admins = [] + bot.send_msg = Mock() + bot._remember_target = Mock() + bot._forward_to_message_chain = Mock() + + payload = { + "body": { + "from": {"userid": "wxbot-user"}, + "msgtype": "image", + "image": { + "download_url": "https://example.com/encrypted-image", + "aeskey": "YWJjZGVmZw", + }, + } + } + + bot._handle_callback_message(payload) + + bot._remember_target.assert_called_once_with("wxbot-user") + bot._forward_to_message_chain.assert_called_once_with(payload) + + def test_vocechat_message_parser_extracts_image_file_payload(self): + module = VoceChatModule() + body = json.dumps( + { + "detail": { + "type": "normal", + "content_type": "vocechat/file", + "content": "/uploads/poster.png", + "properties": {"content_type": "image/png"}, + }, + "from_uid": 7910, + "target": {"gid": 2}, + } + ) + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace( + name="vocechat-test", config={"channel_id": "2"} + ), + ): + message = module.message_parser( + source="vocechat-test", + body=body, + form={}, + args={}, + ) + + self.assertIsNotNone(message) + self.assertEqual( + message.images, + ["vocechat://file/%2Fuploads%2Fposter.png"], + ) + + def test_vocechat_post_message_passes_image_and_correct_target(self): + module = VoceChatModule() + client = Mock() + + with patch.object( + module, + "get_configs", + return_value={"vocechat-test": SimpleNamespace(name="vocechat-test")}, + ), patch.object( + module, "check_message", return_value=True + ), patch.object( + module, "get_instance", return_value=client + ): + module.post_message( + Notification( + title="poster", + image="https://example.com/poster.png", + targets={"vocechat_userid": "UID#100"}, + ) + ) + + client.send_msg.assert_called_once_with( + title="poster", + text=None, + image="https://example.com/poster.png", + userid="UID#100", + link=None, + ) + + def test_qq_message_parser_accepts_image_only_attachment(self): + module = QQBotModule() + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace(name="qq-test", config={}), + ): + message = module.message_parser( + source="qq-test", + body={ + "type": "C2C_MESSAGE_CREATE", + "author": {"user_openid": "qq-user"}, + "attachments": [ + { + "content_type": "image/png", + "url": "https://example.com/qq-image.png", + } + ], + }, + form={}, + args={}, + ) + + self.assertIsNotNone(message) + self.assertEqual(message.images, ["https://example.com/qq-image.png"]) + + def test_synology_message_parser_accepts_image_only_form(self): + module = SynologyChatModule() + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace(name="synology-test", config={}), + ), patch.object( + module, + "get_instance", + return_value=SimpleNamespace(check_token=lambda token: token == "token-1"), + ): + message = module.message_parser( + source="synology-test", + body={}, + form={ + "token": "token-1", + "user_id": "42", + "username": "tester", + "file_url": "https://example.com/image.png", + }, + args={}, + ) + + self.assertIsNotNone(message) + self.assertEqual(message.images, ["https://example.com/image.png"]) + if __name__ == "__main__": unittest.main()