优化 OpenAIVoiceProvider 逻辑,简化凭证与 provider 解析方法并调整最大转录文件大小限制

This commit is contained in:
jxxghp
2026-04-29 18:32:12 +08:00
parent b7749c44fd
commit 38c48fa4ce

View File

@@ -13,7 +13,7 @@ from app.log import logger
class VoiceProvider(ABC):
"""语音 provider 抽象层。"""
MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
@property
@abstractmethod
@@ -49,10 +49,8 @@ class OpenAIVoiceProvider(VoiceProvider):
provider = settings.AI_VOICE_PROVIDER or "openai"
return provider.strip().lower()
@staticmethod
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
mode = mode.lower()
provider = OpenAIVoiceProvider._resolve_provider_name()
def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]:
provider = self._resolve_provider_name()
api_key = settings.AI_VOICE_API_KEY
base_url = settings.AI_VOICE_BASE_URL
@@ -69,17 +67,17 @@ class OpenAIVoiceProvider(VoiceProvider):
def _get_client(self, mode: str):
from openai import OpenAI
api_key, base_url = self._resolve_credentials(mode)
api_key, base_url = self._resolve_credentials()
if not api_key:
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
def is_available_for_stt(self) -> bool:
api_key, _ = self._resolve_credentials("stt")
api_key, _ = self._resolve_credentials()
return bool(api_key)
def is_available_for_tts(self) -> bool:
api_key, _ = self._resolve_credentials("tts")
api_key, _ = self._resolve_credentials()
return bool(api_key)
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
@@ -144,14 +142,12 @@ class VoiceHelper:
"""音频输入输出总开关,以显式配置为准。"""
return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)
@staticmethod
def _resolve_provider_name(mode: str) -> str:
del mode
return OpenAIVoiceProvider._resolve_provider_name()
def _resolve_provider_name(self) -> str:
return self._resolve_provider_name()
@classmethod
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
provider_name = cls._resolve_provider_name(mode)
provider_name = cls._resolve_provider_name()
provider = cls._providers.get(provider_name)
if provider:
return provider