diff --git a/.env.example b/.env.example index aabd8ac..fe158f5 100644 --- a/.env.example +++ b/.env.example @@ -25,4 +25,8 @@ QWEN_API_BASE_URL= QWEN_MODEL= MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen # FFMPEG 配置 -FFMPEG_BIN_PATH= \ No newline at end of file +FFMPEG_BIN_PATH= + +# transcriber 相关配置 +TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou +WHISPER_MODEL_SIZE=base \ No newline at end of file diff --git a/backend/.env.example b/backend/.env.example index c3ee99d..19e42dd 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -21,3 +21,7 @@ QWEN_API_KEY= QWEN_API_BASE_URL= QWEN_MODEL= MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen + +# transcriber 相关配置 +TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou +WHISPER_MODEL_SIZE=base \ No newline at end of file diff --git a/backend/app/services/note.py b/backend/app/services/note.py index 5a3020d..eb85fd9 100644 --- a/backend/app/services/note.py +++ b/backend/app/services/note.py @@ -43,7 +43,7 @@ class NoteGenerator: def __init__(self): self.model_size: str = 'base' self.device: Union[str, None] = None - self.transcriber_type = 'fast-whisper' + self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper') self.transcriber = self.get_transcriber() # TODO 需要更换为可调节 @@ -91,7 +91,7 @@ class NoteGenerator: ''' if self.transcriber_type == 'fast-whisper': logger.info("使用Whisper") - return get_transcriber() + return get_transcriber(transcriber_type='fast-whisper') else: logger.warning("不支持的转义器") raise ValueError(f"不支持的转义器:{self.transcriber}") diff --git a/backend/app/transcriber/transcriber_provider.py b/backend/app/transcriber/transcriber_provider.py index e621612..a59af19 100644 --- a/backend/app/transcriber/transcriber_provider.py +++ b/backend/app/transcriber/transcriber_provider.py @@ -1,3 +1,5 @@ +import os + from app.transcriber.whisper import WhisperTranscriber from app.transcriber.bcut import BcutTranscriber from app.transcriber.kuaishou import KuaishouTranscriber @@ -49,12 +51,12 @@ def get_kuaishou_transcriber(): raise return _transcribers['kuaishou'] -def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda"): +def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"): """ 获取指定类型的转录器实例 参数: - transcriber_type: 转录器类型,支持 "whisper", "bcut", "kuaishou" + transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou" model_size: 模型大小,whisper 特有参数 device: 设备类型,whisper 特有参数 @@ -62,13 +64,14 @@ def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda" 对应类型的转录器实例 """ logger.info(f'获取转录器,类型: {transcriber_type}') - - if transcriber_type == "whisper": - return get_whisper_transcriber(model_size, device) + if transcriber_type == "fast-whisper": + whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size) + return get_whisper_transcriber(whisper_model_size, device=device) elif transcriber_type == "bcut": return get_bcut_transcriber() elif transcriber_type == "kuaishou": return get_kuaishou_transcriber() else: logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper') - return get_whisper_transcriber(model_size, device) \ No newline at end of file + whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size) + return get_whisper_transcriber(whisper_model_size, device) \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index e98005a..3e96270 100644 --- a/backend/main.py +++ b/backend/main.py @@ -34,7 +34,7 @@ async def startup_event(): async def startup_event(): register_handler() ensure_ffmpeg_or_raise() - get_transcriber() + get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper")) init_video_task_table() if __name__ == "__main__":