From 2466c7e1e99931c542da8c6835a9e75f326f2320 Mon Sep 17 00:00:00 2001
From: chenchengpeng
Date: Wed, 16 Apr 2025 14:28:26 +0800
Subject: [PATCH] =?UTF-8?q?feat(transcriber):=20=E6=B7=BB=E5=8A=A0=20Bcut?=
 =?UTF-8?q?=20=E5=92=8C=20Kuaishou=20=E8=BD=AC=E5=BD=95=E5=99=A8=E5=AE=9E?=
 =?UTF-8?q?=E7=8E=B0=EF=BC=8C=E4=BC=98=E5=8C=96=E8=BD=AC=E5=BD=95=E6=9C=8D?=
 =?UTF-8?q?=E5=8A=A1=E6=8F=90=E4=BE=9B=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/app/transcriber/bcut.py               | 286 ++++++++++++++++++
 backend/app/transcriber/kuaishou.py           | 115 ++++++++
 .../app/transcriber/transcriber_provider.py   |  77 +++++-
 3 files changed, 467 insertions(+), 11 deletions(-)
 create mode 100644 backend/app/transcriber/bcut.py
 create mode 100644 backend/app/transcriber/kuaishou.py

diff --git a/backend/app/transcriber/bcut.py b/backend/app/transcriber/bcut.py
new file mode 100644
index 0000000..e1f5c92
--- /dev/null
+++ b/backend/app/transcriber/bcut.py
@@ -0,0 +1,286 @@
+import json
+import logging
+import time
+from typing import Optional, List, Dict, Union
+
+import requests
+
+from app.decorators.timeit import timeit
+from app.models.transcriber_model import TranscriptSegment, TranscriptResult
+from app.transcriber.base import Transcriber
+from app.utils.logger import get_logger
+from events import transcription_finished
+
+__version__ = "0.0.3"
+
+API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
+
+# Request an upload slot
+API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
+
+# Commit a finished upload
+API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
+
+# Create a transcription task
+API_CREATE_TASK = API_BASE_URL + "/task"
+
+# Query a task's result
+API_QUERY_RESULT = API_BASE_URL + "/task/result"
+
+# Seconds before a stalled HTTP request is abandoned instead of hanging forever
+REQUEST_TIMEOUT = 300
+
+logger = get_logger(__name__)
+
+class BcutTranscriber(Transcriber):
+    """Speech recognition through the Bcut (Bilibili) ASR web API."""
+    headers = {
+        'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
+        'Content-Type': 'application/json'
+    }
+
+    def __init__(self):
+        """Open an HTTP session and initialise the per-upload state."""
+        self.session = requests.Session()
+
+        self.__in_boss_key: Optional[str] = None
+        self.__resource_id: Optional[str] = None
+        self.__upload_id: Optional[str] = None
+        self.__upload_urls: List[str] = []
+        self.__per_size: Optional[int] = None
+        self.__clips: Optional[int] = None
+
+        self.__etags: List[str] = []
+        self.__download_url: Optional[str] = None
+        self.task_id: Optional[str] = None
+
+    def _load_file(self, file_path: str) -> bytes:
+        """Return the raw bytes of the file at ``file_path``."""
+        with open(file_path, 'rb') as f:
+            return f.read()
+
+    def _upload(self, file_path: str) -> None:
+        """Request an upload slot, then upload and commit the file.
+
+        Raises:
+            ValueError: if the file is empty.
+            requests.HTTPError: if the API answers with an HTTP error status.
+        """
+        file_binary = self._load_file(file_path)
+        if not file_binary:
+            raise ValueError("无法读取文件数据")
+
+        payload = json.dumps({
+            "type": 2,
+            "name": "audio.mp3",
+            "size": len(file_binary),
+            "ResourceFileType": "mp3",
+            "model_id": "8",
+        })
+
+        resp = self.session.post(
+            API_REQ_UPLOAD,
+            data=payload,
+            headers=self.headers,
+            timeout=REQUEST_TIMEOUT
+        )
+        resp.raise_for_status()
+        resp = resp.json()
+        resp_data = resp["data"]
+
+        self.__in_boss_key = resp_data["in_boss_key"]
+        self.__resource_id = resp_data["resource_id"]
+        self.__upload_id = resp_data["upload_id"]
+        self.__upload_urls = resp_data["upload_urls"]
+        self.__per_size = resp_data["per_size"]
+        self.__clips = len(resp_data["upload_urls"])
+
+        logger.info(
+            f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
+        )
+        self.__upload_part(file_binary)
+        self.__commit_upload()
+
+    def __upload_part(self, file_binary: bytes) -> None:
+        """Upload ``file_binary`` chunk by chunk and record each chunk's ETag."""
+        # One instance may serve several transcriptions, so drop the ETags of
+        # a previous upload; otherwise they would be committed again.
+        self.__etags = []
+        for clip in range(self.__clips):
+            start_range = clip * self.__per_size
+            end_range = min((clip + 1) * self.__per_size, len(file_binary))
+            logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
+            resp = self.session.put(
+                self.__upload_urls[clip],
+                data=file_binary[start_range:end_range],
+                headers={'Content-Type': 'application/octet-stream'},
+                timeout=REQUEST_TIMEOUT
+            )
+            resp.raise_for_status()
+            etag = resp.headers.get("Etag", "").strip('"')
+            self.__etags.append(etag)
+            logger.info(f"分片{clip}上传成功: {etag}")
+
+    def __commit_upload(self) -> None:
+        """Report all chunks as uploaded and store the returned download URL.
+
+        Raises:
+            Exception: if the API reports a non-zero ``code``.
+        """
+        data = json.dumps({
+            "InBossKey": self.__in_boss_key,
+            "ResourceId": self.__resource_id,
+            "Etags": ",".join(self.__etags),
+            "UploadId": self.__upload_id,
+            "model_id": "8",
+        })
+        resp = self.session.post(
+            API_COMMIT_UPLOAD,
+            data=data,
+            headers=self.headers,
+            timeout=REQUEST_TIMEOUT
+        )
+        resp.raise_for_status()
+        resp = resp.json()
+        if resp.get("code") != 0:
+            error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
+            logger.error(error_msg)
+            raise Exception(error_msg)
+
+        self.__download_url = resp["data"]["download_url"]
+        logger.info(f"提交成功,下载链接: {self.__download_url}")
+
+    def _create_task(self) -> str:
+        """Create a transcription task for the uploaded file and return its id.
+
+        Raises:
+            Exception: if the API reports a non-zero ``code``.
+        """
+        resp = self.session.post(
+            API_CREATE_TASK,
+            json={"resource": self.__download_url, "model_id": "8"},
+            headers=self.headers,
+            timeout=REQUEST_TIMEOUT
+        )
+        resp.raise_for_status()
+        resp = resp.json()
+        if resp.get("code") != 0:
+            error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
+            logger.error(error_msg)
+            raise Exception(error_msg)
+
+        self.task_id = resp["data"]["task_id"]
+        logger.info(f"任务已创建: {self.task_id}")
+        return self.task_id
+
+    def _query_result(self) -> dict:
+        """Return the ``data`` payload describing the state of ``self.task_id``.
+
+        Raises:
+            Exception: if the API reports a non-zero ``code``.
+        """
+        resp = self.session.get(
+            API_QUERY_RESULT,
+            # NOTE(review): every other request sends model_id "8" - confirm 7 is intended
+            params={"model_id": 7, "task_id": self.task_id},
+            headers=self.headers,
+            timeout=REQUEST_TIMEOUT
+        )
+        resp.raise_for_status()
+        resp = resp.json()
+        if resp.get("code") != 0:
+            error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
+            logger.error(error_msg)
+            raise Exception(error_msg)
+
+        return resp["data"]
+
+    @timeit
+    def transcript(self, file_path: str) -> TranscriptResult:
+        """Upload ``file_path``, run Bcut ASR on it and return the transcript.
+
+        Polls the task once per second, at most ``max_retries`` times.
+
+        Raises:
+            Exception: if the task fails or does not finish within the polls.
+        """
+        try:
+            logger.info(f"开始处理文件: {file_path}")
+
+            # Upload the file
+            logger.info("正在上传文件...")
+            self._upload(file_path)
+
+            # Create the task
+            logger.info("提交转录任务...")
+            self._create_task()
+
+            # Poll the task state
+            logger.info("等待转录结果...")
+            task_resp = None
+            max_retries = 500
+            for i in range(max_retries):
+                task_resp = self._query_result()
+
+                if task_resp["state"] == 4:  # finished
+                    break
+                elif task_resp["state"] == 3:  # failed
+                    error_msg = f"B站ASR任务失败,状态码: {task_resp['state']}"
+                    logger.error(error_msg)
+                    raise Exception(error_msg)
+
+                # Log progress every tenth poll
+                if i % 10 == 0:
+                    logger.info(f"转录进行中... {i}/{max_retries}")
+
+                time.sleep(1)
+
+            if not task_resp or task_resp["state"] != 4:
+                error_msg = f"B站ASR任务未能完成,状态: {task_resp.get('state') if task_resp else 'Unknown'}"
+                logger.error(error_msg)
+                raise Exception(error_msg)
+
+            # Parse the result
+            logger.info("转录成功,处理结果...")
+            result_json = json.loads(task_resp["result"])
+
+            # Collect the segments
+            segments = []
+            full_text = ""
+
+            for u in result_json.get("utterances", []):
+                text = u.get("transcript", "").strip()
+                # Bcut ASR timestamps are in milliseconds; convert them to seconds
+                start_time = float(u.get("start_time", 0)) / 1000.0
+                end_time = float(u.get("end_time", 0)) / 1000.0
+
+                full_text += text + " "
+                segments.append(TranscriptSegment(
+                    start=start_time,
+                    end=end_time,
+                    text=text
+                ))
+
+            # Build the result object
+            result = TranscriptResult(
+                language=result_json.get("language", "zh"),
+                full_text=full_text.strip(),
+                segments=segments,
+                raw=result_json
+            )
+
+            # Fire the completion event
+            self.on_finish(file_path, result)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"B站ASR处理失败: {str(e)}")
+            raise
+
+    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
+        """Log completion and send the ``transcription_finished`` event."""
+        logger.info(f"B站ASR转写完成: {video_path}")
+        transcription_finished.send({
+            "file_path": video_path,
+        })
\ No newline at end of file
diff --git a/backend/app/transcriber/kuaishou.py b/backend/app/transcriber/kuaishou.py
new file mode 100644
index 0000000..4d7c4b4
--- /dev/null
+++ b/backend/app/transcriber/kuaishou.py
@@ -0,0 +1,115 @@
+import requests
+import logging
+import os
+from typing import Union, List, Dict, Optional
+
+from app.decorators.timeit import timeit
+from app.models.transcriber_model import TranscriptSegment, TranscriptResult
+from app.transcriber.base import Transcriber
+from app.utils.logger import get_logger
+from events import transcription_finished
+
+logger = get_logger(__name__)
+
+class KuaishouTranscriber(Transcriber):
+    """Speech recognition through the Kuaishou subtitle-generation API."""
+
+    API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"
+
+    def __init__(self):
+        pass
+
+    def _load_file(self, file_path: str) -> bytes:
+        """Return the raw bytes of the file at ``file_path``."""
+        with open(file_path, 'rb') as f:
+            return f.read()
+
+    def _submit(self, file_path: str) -> dict:
+        """Post the file to the API and return the decoded JSON response."""
+        try:
+            file_binary = self._load_file(file_path)
+
+            payload = {
+                "typeId": "1"
+            }
+
+            # Upload under the file's own base name
+            file_name = os.path.basename(file_path)
+            files = [('file', (file_name, file_binary, 'audio/mpeg'))]
+
+            logger.info(f"开始向快手API提交请求,文件: {file_name}")
+            response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
+            response.raise_for_status()  # raise on HTTP error statuses
+
+            result = response.json()
+
+            # Treat a missing "data" field or a non-zero "code" as an API error
+            if "data" not in result or result.get("code", 0) != 0:
+                error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
+                logger.error(error_msg)
+                raise Exception(error_msg)
+
+            return result
+
+        except requests.exceptions.RequestException as e:
+            error_msg = f"快手ASR请求网络错误: {str(e)}"
+            logger.error(error_msg)
+            raise
+        except Exception as e:
+            error_msg = f"快手ASR请求处理错误: {str(e)}"
+            logger.error(error_msg)
+            raise
+
+    @timeit
+    def transcript(self, file_path: str) -> TranscriptResult:
+        """Transcribe ``file_path`` via the Kuaishou API (Transcriber interface)."""
+        try:
+            logger.info(f"开始处理文件: {file_path}")
+
+            # Submit the request and fetch the result
+            logger.info("向快手API提交识别请求...")
+            result_data = self._submit(file_path)
+
+            logger.info("请求成功,处理结果...")
+
+            # Collect the segments
+            segments = []
+            full_text = ""
+
+            # Walk the text segments returned by the Kuaishou API
+            texts = result_data.get('data', {}).get('text', [])
+            for u in texts:
+                text = u.get('text', '').strip()
+                start_time = float(u.get('start_time', 0))
+                end_time = float(u.get('end_time', 0))
+
+                full_text += text + " "
+                segments.append(TranscriptSegment(
+                    start=start_time,
+                    end=end_time,
+                    text=text
+                ))
+
+            # Build the result object
+            result = TranscriptResult(
+                language="zh",  # the Kuaishou API may not report a language; default to Chinese
+                full_text=full_text.strip(),
+                segments=segments,
+                raw=result_data
+            )
+
+            # Fire the completion event
+            self.on_finish(file_path, result)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"快手ASR处理失败: {str(e)}")
+            raise
+
+    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
+        """Log completion and send the ``transcription_finished`` event."""
+        logger.info(f"快手ASR转写完成: {video_path}")
+        transcription_finished.send({
+            "file_path": video_path,
+        })
\ No newline at end of file
diff --git a/backend/app/transcriber/transcriber_provider.py b/backend/app/transcriber/transcriber_provider.py
index b074d3e..e621612 100644
--- a/backend/app/transcriber/transcriber_provider.py
+++ b/backend/app/transcriber/transcriber_provider.py
@@ -1,19 +1,74 @@
 from app.transcriber.whisper import WhisperTranscriber
+from app.transcriber.bcut import BcutTranscriber
+from app.transcriber.kuaishou import KuaishouTranscriber
 from app.utils.logger import get_logger
 
 logger = get_logger(__name__)
-logger.info('实例化transcriber')
-# TODO:后面需要加入逻辑选择
-_transcriber = None
+logger.info('初始化转录服务提供器')
 
-def get_transcriber(model_size="base", device="cuda"):
-    global _transcriber
+# One lazily created, shared instance per transcriber type
+_transcribers = {
+    'whisper': None,
+    'bcut': None,
+    'kuaishou': None
+}
 
-    if _transcriber is None:
-        logger.info('不存在 transcriber ,开始实例化transcriber。')
+def get_whisper_transcriber(model_size="base", device="cuda"):
+    """Return the shared Whisper transcriber; arguments only apply on first creation."""
+    if _transcribers['whisper'] is None:
+        logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
         try:
-            _transcriber = WhisperTranscriber(model_size=model_size, device=device)
-            logger.info(f'实例化transcriber成功。参数:{model_size}, {device} ')
+            _transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
+            logger.info('Whisper 转录器创建成功')
         except Exception as e:
-            logger.error(f"实例化transcriber失败,请检查是否安装whisper。{e}")
-    return _transcriber
\ No newline at end of file
+            logger.error(f"Whisper 转录器创建失败: {e}")
+            raise
+    return _transcribers['whisper']
+
+def get_bcut_transcriber():
+    """Return the shared Bcut transcriber, creating it on first use."""
+    if _transcribers['bcut'] is None:
+        logger.info('创建 Bcut 转录器实例')
+        try:
+            _transcribers['bcut'] = BcutTranscriber()
+            logger.info('Bcut 转录器创建成功')
+        except Exception as e:
+            logger.error(f"Bcut 转录器创建失败: {e}")
+            raise
+    return _transcribers['bcut']
+
+def get_kuaishou_transcriber():
+    """Return the shared Kuaishou transcriber, creating it on first use."""
+    if _transcribers['kuaishou'] is None:
+        logger.info('创建快手转录器实例')
+        try:
+            _transcribers['kuaishou'] = KuaishouTranscriber()
+            logger.info('快手转录器创建成功')
+        except Exception as e:
+            logger.error(f"快手转录器创建失败: {e}")
+            raise
+    return _transcribers['kuaishou']
+
+def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda"):
+    """
+    Return the transcriber instance for the requested type.
+
+    Args:
+        transcriber_type: one of "whisper", "bcut" or "kuaishou"
+        model_size: model size; only used by whisper
+        device: device type; only used by whisper
+
+    Returns:
+        The transcriber instance of that type; unknown types fall back to whisper.
+    """
+    logger.info(f'获取转录器,类型: {transcriber_type}')
+
+    if transcriber_type == "whisper":
+        return get_whisper_transcriber(model_size, device)
+    elif transcriber_type == "bcut":
+        return get_bcut_transcriber()
+    elif transcriber_type == "kuaishou":
+        return get_kuaishou_transcriber()
+    else:
+        logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
+        return get_whisper_transcriber(model_size, device)
\ No newline at end of file