mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
Merge pull request #16 from ccp-p/master
feat(transcriber): 添加 Bcut 和 Kuaishou 转录器实现,优化转录服务提供器
This commit is contained in:
250
backend/app/transcriber/bcut.py
Normal file
250
backend/app/transcriber/bcut.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
import requests
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
__version__ = "0.0.3"
|
||||
|
||||
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
|
||||
|
||||
# 申请上传
|
||||
API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
|
||||
|
||||
# 提交上传
|
||||
API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
|
||||
|
||||
# 创建任务
|
||||
API_CREATE_TASK = API_BASE_URL + "/task"
|
||||
|
||||
# 查询结果
|
||||
API_QUERY_RESULT = API_BASE_URL + "/task/result"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class BcutTranscriber(Transcriber):
|
||||
"""必剪 语音识别接口"""
|
||||
headers = {
|
||||
'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.task_id = None
|
||||
self.__etags = []
|
||||
|
||||
self.__in_boss_key: Optional[str] = None
|
||||
self.__resource_id: Optional[str] = None
|
||||
self.__upload_id: Optional[str] = None
|
||||
self.__upload_urls: List[str] = []
|
||||
self.__per_size: Optional[int] = None
|
||||
self.__clips: Optional[int] = None
|
||||
|
||||
self.__etags: List[str] = []
|
||||
self.__download_url: Optional[str] = None
|
||||
self.task_id: Optional[str] = None
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _upload(self, file_path: str) -> None:
|
||||
"""申请上传"""
|
||||
file_binary = self._load_file(file_path)
|
||||
if not file_binary:
|
||||
raise ValueError("无法读取文件数据")
|
||||
|
||||
payload = json.dumps({
|
||||
"type": 2,
|
||||
"name": "audio.mp3",
|
||||
"size": len(file_binary),
|
||||
"ResourceFileType": "mp3",
|
||||
"model_id": "8",
|
||||
})
|
||||
|
||||
resp = self.session.post(
|
||||
API_REQ_UPLOAD,
|
||||
data=payload,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
resp_data = resp["data"]
|
||||
|
||||
self.__in_boss_key = resp_data["in_boss_key"]
|
||||
self.__resource_id = resp_data["resource_id"]
|
||||
self.__upload_id = resp_data["upload_id"]
|
||||
self.__upload_urls = resp_data["upload_urls"]
|
||||
self.__per_size = resp_data["per_size"]
|
||||
self.__clips = len(resp_data["upload_urls"])
|
||||
|
||||
logger.info(
|
||||
f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
|
||||
)
|
||||
self.__upload_part(file_binary)
|
||||
self.__commit_upload()
|
||||
|
||||
def __upload_part(self, file_binary: bytes) -> None:
|
||||
"""上传音频数据"""
|
||||
for clip in range(self.__clips):
|
||||
start_range = clip * self.__per_size
|
||||
end_range = min((clip + 1) * self.__per_size, len(file_binary))
|
||||
logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
|
||||
resp = self.session.put(
|
||||
self.__upload_urls[clip],
|
||||
data=file_binary[start_range:end_range],
|
||||
headers={'Content-Type': 'application/octet-stream'}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
etag = resp.headers.get("Etag", "").strip('"')
|
||||
self.__etags.append(etag)
|
||||
logger.info(f"分片{clip}上传成功: {etag}")
|
||||
|
||||
def __commit_upload(self) -> None:
|
||||
"""提交上传数据"""
|
||||
data = json.dumps({
|
||||
"InBossKey": self.__in_boss_key,
|
||||
"ResourceId": self.__resource_id,
|
||||
"Etags": ",".join(self.__etags),
|
||||
"UploadId": self.__upload_id,
|
||||
"model_id": "8",
|
||||
})
|
||||
resp = self.session.post(
|
||||
API_COMMIT_UPLOAD,
|
||||
data=data,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.__download_url = resp["data"]["download_url"]
|
||||
logger.info(f"提交成功,下载链接: {self.__download_url}")
|
||||
|
||||
def _create_task(self) -> str:
|
||||
"""开始创建转换任务"""
|
||||
resp = self.session.post(
|
||||
API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.task_id = resp["data"]["task_id"]
|
||||
logger.info(f"任务已创建: {self.task_id}")
|
||||
return self.task_id
|
||||
|
||||
def _query_result(self) -> dict:
|
||||
"""查询转换结果"""
|
||||
resp = self.session.get(
|
||||
API_QUERY_RESULT,
|
||||
params={"model_id": 7, "task_id": self.task_id},
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return resp["data"]
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行识别过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 上传文件
|
||||
logger.info("正在上传文件...")
|
||||
self._upload(file_path)
|
||||
|
||||
# 创建任务
|
||||
logger.info("提交转录任务...")
|
||||
self._create_task()
|
||||
|
||||
# 轮询检查任务状态
|
||||
logger.info("等待转录结果...")
|
||||
task_resp = None
|
||||
max_retries = 500
|
||||
for i in range(max_retries):
|
||||
task_resp = self._query_result()
|
||||
|
||||
if task_resp["state"] == 4: # 完成状态
|
||||
break
|
||||
elif task_resp["state"] == 3: # 失败状态
|
||||
error_msg = f"B站ASR任务失败,状态码: {task_resp['state']}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 每隔一段时间打印进度
|
||||
if i % 10 == 0:
|
||||
logger.info(f"转录进行中... {i}/{max_retries}")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if not task_resp or task_resp["state"] != 4:
|
||||
error_msg = f"B站ASR任务未能完成,状态: {task_resp.get('state') if task_resp else 'Unknown'}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析结果
|
||||
logger.info("转录成功,处理结果...")
|
||||
result_json = json.loads(task_resp["result"])
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
for u in result_json.get("utterances", []):
|
||||
text = u.get("transcript", "").strip()
|
||||
# B站ASR返回的时间戳是毫秒,需要转换为秒
|
||||
start_time = float(u.get("start_time", 0)) / 1000.0
|
||||
end_time = float(u.get("end_time", 0)) / 1000.0
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language=result_json.get("language", "zh"),
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_json
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"B站ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"B站ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
115
backend/app/transcriber/kuaishou.py
Normal file
115
backend/app/transcriber/kuaishou.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
from typing import Union, List, Dict, Optional
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class KuaishouTranscriber(Transcriber):
|
||||
"""快手语音识别实现"""
|
||||
|
||||
API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _submit(self, file_path: str) -> dict:
|
||||
"""提交识别请求"""
|
||||
try:
|
||||
file_binary = self._load_file(file_path)
|
||||
|
||||
payload = {
|
||||
"typeId": "1"
|
||||
}
|
||||
|
||||
# 使用文件名作为上传文件名
|
||||
file_name = os.path.basename(file_path)
|
||||
files = [('file', (file_name, file_binary, 'audio/mpeg'))]
|
||||
|
||||
logger.info(f"开始向快手API提交请求,文件: {file_name}")
|
||||
response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
|
||||
response.raise_for_status() # 检查HTTP错误
|
||||
|
||||
result = response.json()
|
||||
|
||||
# 检查快手API返回是否包含错误
|
||||
if "data" not in result or result.get("code", 0) != 0:
|
||||
error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"快手ASR请求网络错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"快手ASR请求处理错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行转录过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 提交请求并获取结果
|
||||
logger.info("向快手API提交识别请求...")
|
||||
result_data = self._submit(file_path)
|
||||
|
||||
logger.info("请求成功,处理结果...")
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
# 解析快手API返回的文本段
|
||||
texts = result_data.get('data', {}).get('text', [])
|
||||
for u in texts:
|
||||
text = u.get('text', '').strip()
|
||||
start_time = float(u.get('start_time', 0))
|
||||
end_time = float(u.get('end_time', 0))
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language="zh", # 快手API可能不返回语言信息,默认为中文
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_data
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"快手ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"快手ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
@@ -1,19 +1,74 @@
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
from app.transcriber.bcut import BcutTranscriber
|
||||
from app.transcriber.kuaishou import KuaishouTranscriber
|
||||
from app.utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info('实例化transcriber')
|
||||
# TODO:后面需要加入逻辑选择
|
||||
_transcriber = None
|
||||
logger.info('初始化转录服务提供器')
|
||||
|
||||
def get_transcriber(model_size="base", device="cuda"):
|
||||
global _transcriber
|
||||
# 维护各种转录器的单例实例
|
||||
_transcribers = {
|
||||
'whisper': None,
|
||||
'bcut': None,
|
||||
'kuaishou': None
|
||||
}
|
||||
|
||||
if _transcriber is None:
|
||||
logger.info('不存在 transcriber ,开始实例化transcriber。')
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
"""获取 Whisper 转录器实例"""
|
||||
if _transcribers['whisper'] is None:
|
||||
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
|
||||
try:
|
||||
_transcriber = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info(f'实例化transcriber成功。参数:{model_size}, {device} ')
|
||||
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info('Whisper 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"实例化transcriber失败,请检查是否安装whisper。{e}")
|
||||
return _transcriber
|
||||
logger.error(f"Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['whisper']
|
||||
|
||||
def get_bcut_transcriber():
|
||||
"""获取 Bcut 转录器实例"""
|
||||
if _transcribers['bcut'] is None:
|
||||
logger.info('创建 Bcut 转录器实例')
|
||||
try:
|
||||
_transcribers['bcut'] = BcutTranscriber()
|
||||
logger.info('Bcut 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Bcut 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['bcut']
|
||||
|
||||
def get_kuaishou_transcriber():
|
||||
"""获取快手转录器实例"""
|
||||
if _transcribers['kuaishou'] is None:
|
||||
logger.info('创建快手转录器实例')
|
||||
try:
|
||||
_transcribers['kuaishou'] = KuaishouTranscriber()
|
||||
logger.info('快手转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"快手转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['kuaishou']
|
||||
|
||||
def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda"):
|
||||
"""
|
||||
获取指定类型的转录器实例
|
||||
|
||||
参数:
|
||||
transcriber_type: 转录器类型,支持 "whisper", "bcut", "kuaishou"
|
||||
model_size: 模型大小,whisper 特有参数
|
||||
device: 设备类型,whisper 特有参数
|
||||
|
||||
返回:
|
||||
对应类型的转录器实例
|
||||
"""
|
||||
logger.info(f'获取转录器,类型: {transcriber_type}')
|
||||
|
||||
if transcriber_type == "whisper":
|
||||
return get_whisper_transcriber(model_size, device)
|
||||
elif transcriber_type == "bcut":
|
||||
return get_bcut_transcriber()
|
||||
elif transcriber_type == "kuaishou":
|
||||
return get_kuaishou_transcriber()
|
||||
else:
|
||||
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
|
||||
return get_whisper_transcriber(model_size, device)
|
||||
Reference in New Issue
Block a user