"""语音能力辅助功能。""" from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path from typing import Dict, Optional from uuid import uuid4 from app.core.config import settings from app.log import logger class VoiceProvider(ABC): """语音 provider 抽象层。""" MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024 @property @abstractmethod def name(self) -> str: """provider 名称。""" @abstractmethod def is_available_for_stt(self) -> bool: """是否可用于语音识别。""" @abstractmethod def is_available_for_tts(self) -> bool: """是否可用于语音合成。""" @abstractmethod def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: """将音频字节转成文字。""" @abstractmethod def synthesize_speech(self, text: str) -> Optional[Path]: """将文字转成语音文件。""" class OpenAIVoiceProvider(VoiceProvider): """OpenAI / OpenAI-compatible provider。""" @property def name(self) -> str: return "openai" @staticmethod def _resolve_provider_name() -> str: provider = settings.AI_VOICE_PROVIDER or "openai" return provider.strip().lower() def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]: provider = self._resolve_provider_name() api_key = settings.AI_VOICE_API_KEY base_url = settings.AI_VOICE_BASE_URL if ( not api_key and provider == "openai" and (settings.LLM_PROVIDER or "").strip().lower() == "openai" ): api_key = settings.LLM_API_KEY base_url = base_url or settings.LLM_BASE_URL return api_key, base_url def _get_client(self, mode: str): from openai import OpenAI api_key, base_url = self._resolve_credentials() if not api_key: raise ValueError(f"{mode.upper()} provider 未配置 API Key") return OpenAI(api_key=api_key, base_url=base_url, max_retries=3) def is_available_for_stt(self) -> bool: api_key, _ = self._resolve_credentials() return bool(api_key) def is_available_for_tts(self) -> bool: api_key, _ = self._resolve_credentials() return bool(api_key) def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: if not content: return None if len(content) > self.MAX_TRANSCRIBE_BYTES: raise ValueError("语音文件超过 10MB,无法识别") try: client = self._get_client("stt") audio_file = BytesIO(content) audio_file.name = filename response = client.audio.transcriptions.create( model=settings.AI_VOICE_STT_MODEL, file=audio_file, language=settings.AI_VOICE_LANGUAGE or "zh", response_format="verbose_json", ) text = getattr(response, "text", None) return text.strip() if text else None except Exception as err: logger.error(f"语音转文字失败: provider={self.name}, error={err}") return None def synthesize_speech(self, text: str) -> Optional[Path]: if not text: return None try: client = self._get_client("tts") voice_dir = settings.TEMP_PATH / "voice" voice_dir.mkdir(parents=True, exist_ok=True) output_path = voice_dir / f"{uuid4().hex}.opus" response = client.audio.speech.create( model=settings.AI_VOICE_TTS_MODEL, voice=settings.AI_VOICE_TTS_VOICE, input=text, response_format="opus", ) response.write_to_file(output_path) return output_path except Exception as err: logger.error(f"文字转语音失败: provider={self.name}, error={err}") return None class VoiceHelper: """统一语音入口,负责音频能力判断与 STT/TTS provider 路由。""" _providers: Dict[str, VoiceProvider] = { "openai": OpenAIVoiceProvider(), } REPLY_MODE_NATIVE = "native_voice" REPLY_MODE_TEXT = "text" @classmethod def register_provider(cls, provider: VoiceProvider) -> None: cls._providers[provider.name.lower()] = provider @staticmethod def is_enabled() -> bool: """音频输入输出总开关,以显式配置为准。""" return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT) @staticmethod def _resolve_provider_name() -> str: """标准化当前配置的语音 provider 名称。""" provider = settings.AI_VOICE_PROVIDER or "openai" return provider.strip().lower() @classmethod def get_provider(cls, mode: str) -> Optional[VoiceProvider]: provider_name = cls._resolve_provider_name() provider = cls._providers.get(provider_name) if provider: return provider logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}") return None @classmethod def get_registered_providers(cls) -> list[str]: return sorted(cls._providers.keys()) @classmethod def is_available(cls, mode: Optional[str] = None) -> bool: if not cls.is_enabled(): return False if mode: provider = cls.get_provider(mode) if not provider: return False return ( provider.is_available_for_stt() if mode.lower() == "stt" else provider.is_available_for_tts() ) return cls.is_available("stt") or cls.is_available("tts") @classmethod def supports_native_voice_reply( cls, channel: Optional[str], source: Optional[str] ) -> bool: """ 判断当前渠道是否支持原生语音消息发送。 """ if not channel: return False from app.helper.service import ServiceConfigHelper from app.schemas.types import MessageChannel try: channel_enum = MessageChannel(channel) except (TypeError, ValueError): return False if channel_enum == MessageChannel.Telegram: return True if channel_enum != MessageChannel.Wechat: return False # 企业微信 bot 模式不支持发送语音,只有应用模式可用。 for config in ServiceConfigHelper.get_notification_configs(): if config.name != source: continue return (config.config or {}).get("WECHAT_MODE", "app") != "bot" return False @classmethod def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str: """ 仅在支持原生语音回复的渠道上发送音频,其余渠道统一回退文字。 """ if cls.supports_native_voice_reply(channel=channel, source=source): return cls.REPLY_MODE_NATIVE return cls.REPLY_MODE_TEXT @classmethod def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]: if not cls.is_enabled(): return None provider = cls.get_provider("stt") if not provider: return None return provider.transcribe_bytes(content=content, filename=filename) @classmethod def synthesize_speech(cls, text: str) -> Optional[Path]: if not cls.is_enabled(): return None provider = cls.get_provider("tts") if not provider: return None return provider.synthesize_speech(text=text)