feat: 新增模型管理和供应商配置功能

### v1.1.0
- #### Added
  - 新增 AI 笔记风格选择
  - 新增 AI 笔记返回格式选择
  - 添加 AI 自定义笔记备注 Prompt
  - 添加任务失败重试
  - 添加全局设置页,可在设置页进行模型设置

- #### Optimize
  - 优化前端样式,优化用户体验
  - 增加生成中间产物,可用于失败后加快生成速度
- #### Fix
  - 修复视频截图视频过早删除错误
This commit is contained in:
思诺特
2025-04-26 23:40:17 +08:00
parent 1323cfd1ec
commit 171dea5e0d
51 changed files with 2511 additions and 414 deletions

View File

@@ -0,0 +1,251 @@
import json
import logging
import time
from typing import Optional, List, Dict, Union
import requests
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
__version__ = "0.0.3"

# Root of the Bcut "rubick-interface" API; every endpoint below is a path under it.
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"

# Request an upload
API_REQ_UPLOAD = f"{API_BASE_URL}/resource/create"
# Commit the upload
API_COMMIT_UPLOAD = f"{API_BASE_URL}/resource/create/complete"
# Create a task
API_CREATE_TASK = f"{API_BASE_URL}/task"
# Query the result
API_QUERY_RESULT = f"{API_BASE_URL}/task/result"
logger = get_logger(__name__)
class BcutTranscriber(Transcriber):
    """Speech-recognition client for the Bcut ASR web API.

    A transcription runs in four steps: request an upload, upload the audio
    in parts, commit the upload, then create a task and poll for its result.
    """

    headers = {
        'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
        'Content-Type': 'application/json'
    }

    # Seconds passed as ``timeout`` to every HTTP call so that a stalled
    # connection cannot block the caller forever.
    request_timeout = 300

    def __init__(self):
        """Create the HTTP session and initialise the per-upload state."""
        self.session = requests.Session()
        self.task_id: Optional[str] = None
        self.__in_boss_key: Optional[str] = None
        self.__resource_id: Optional[str] = None
        self.__upload_id: Optional[str] = None
        self.__upload_urls: List[str] = []
        self.__per_size: Optional[int] = None
        self.__clips: Optional[int] = None
        self.__etags: List[str] = []
        self.__download_url: Optional[str] = None

    def _load_file(self, file_path: str) -> bytes:
        """Return the raw bytes of the file at ``file_path``."""
        with open(file_path, 'rb') as f:
            return f.read()

    def _upload(self, file_path: str) -> None:
        """Request an upload, send the file in parts and commit it.

        Args:
            file_path: Path of the audio file to upload.

        Raises:
            ValueError: If the file is empty.
            Exception: If the API reports an error or returns no data.
            requests.HTTPError: If a request returns an HTTP error status.
        """
        file_binary = self._load_file(file_path)
        if not file_binary:
            raise ValueError("无法读取文件数据")

        # An instance may be reused for several files; ETags left over from a
        # previous upload would otherwise be committed together with the new ones.
        self.__etags = []

        # NOTE(review): upload/commit/create use model_id "8" while
        # _query_result uses 7 - confirm which value the API expects.
        payload = json.dumps({
            "type": 2,
            "name": "audio.mp3",
            "size": len(file_binary),
            "ResourceFileType": "mp3",
            "model_id": "8",
        })
        resp = self.session.post(
            API_REQ_UPLOAD,
            data=payload,
            headers=self.headers,
            timeout=self.request_timeout,
        )
        resp.raise_for_status()
        resp = resp.json()

        # Validate like the other endpoints instead of failing later with a
        # bare KeyError/TypeError; a missing "code" is treated as success.
        resp_data = resp.get("data")
        if resp.get("code", 0) != 0 or not resp_data:
            error_msg = f"申请上传失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)

        self.__in_boss_key = resp_data["in_boss_key"]
        self.__resource_id = resp_data["resource_id"]
        self.__upload_id = resp_data["upload_id"]
        self.__upload_urls = resp_data["upload_urls"]
        self.__per_size = resp_data["per_size"]
        self.__clips = len(resp_data["upload_urls"])
        logger.info(
            f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
        )
        self.__upload_part(file_binary)
        self.__commit_upload()

    def __upload_part(self, file_binary: bytes) -> None:
        """Upload ``file_binary`` part by part and record each part's ETag.

        Part ``i`` covers bytes ``[i * per_size, (i + 1) * per_size)``, clipped
        to the file length, and is PUT to the i-th upload URL.
        """
        total = len(file_binary)
        for clip, upload_url in enumerate(self.__upload_urls):
            start_range = clip * self.__per_size
            end_range = min((clip + 1) * self.__per_size, total)
            logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
            resp = self.session.put(
                upload_url,
                data=file_binary[start_range:end_range],
                headers={'Content-Type': 'application/octet-stream'},
                timeout=self.request_timeout,
            )
            resp.raise_for_status()
            # The header value is wrapped in double quotes; strip them.
            etag = resp.headers.get("Etag", "").strip('"')
            self.__etags.append(etag)
            logger.info(f"分片{clip}上传成功: {etag}")

    def __commit_upload(self) -> None:
        """Commit the uploaded parts and store the resulting download URL.

        Raises:
            Exception: If the API answers with a non-zero ``code``.
            requests.HTTPError: If the request returns an HTTP error status.
        """
        data = json.dumps({
            "InBossKey": self.__in_boss_key,
            "ResourceId": self.__resource_id,
            "Etags": ",".join(self.__etags),
            "UploadId": self.__upload_id,
            "model_id": "8",
        })
        resp = self.session.post(
            API_COMMIT_UPLOAD,
            data=data,
            headers=self.headers,
            timeout=self.request_timeout,
        )
        resp.raise_for_status()
        resp = resp.json()
        # Diagnostic dump goes through the logger rather than stdout.
        logger.debug(f"Bili {resp}")
        if resp.get("code") != 0:
            error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)
        self.__download_url = resp["data"]["download_url"]
        logger.info(f"提交成功,下载链接: {self.__download_url}")

    def _create_task(self) -> str:
        """Create a transcription task for the committed upload.

        Returns:
            The task id, which is also stored on ``self.task_id``.

        Raises:
            Exception: If the API answers with a non-zero ``code``.
            requests.HTTPError: If the request returns an HTTP error status.
        """
        resp = self.session.post(
            API_CREATE_TASK,
            json={"resource": self.__download_url, "model_id": "8"},
            headers=self.headers,
            timeout=self.request_timeout,
        )
        resp.raise_for_status()
        resp = resp.json()
        if resp.get("code") != 0:
            error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)
        self.task_id = resp["data"]["task_id"]
        logger.info(f"任务已创建: {self.task_id}")
        return self.task_id

    def _query_result(self) -> dict:
        """Fetch the current state of the task ``self.task_id``.

        Returns:
            The ``data`` object of the API response.

        Raises:
            Exception: If the API answers with a non-zero ``code``.
            requests.HTTPError: If the request returns an HTTP error status.
        """
        resp = self.session.get(
            API_QUERY_RESULT,
            params={"model_id": 7, "task_id": self.task_id},
            headers=self.headers,
            timeout=self.request_timeout,
        )
        resp.raise_for_status()
        resp = resp.json()
        if resp.get("code") != 0:
            error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)
        return resp["data"]

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Transcribe the audio file at ``file_path``.

        Uploads the file, creates a task and polls it once per second for at
        most 500 attempts.

        Returns:
            A ``TranscriptResult`` whose segment times are in seconds.

        Raises:
            Exception: If any step fails, the task reports failure, or the
                task does not finish within the polling budget. Failures are
                logged and re-raised.
        """
        try:
            logger.info(f"开始处理文件: {file_path}")

            # Upload the file
            logger.info("正在上传文件...")
            self._upload(file_path)

            # Create the task
            logger.info("提交转录任务...")
            self._create_task()

            # Poll the task state
            logger.info("等待转录结果...")
            task_resp = None
            max_retries = 500
            for i in range(max_retries):
                task_resp = self._query_result()
                if task_resp["state"] == 4:  # finished
                    break
                elif task_resp["state"] == 3:  # failed
                    error_msg = f"B站ASR任务失败状态码: {task_resp['state']}"
                    logger.error(error_msg)
                    raise Exception(error_msg)
                # Log progress every ten polls
                if i % 10 == 0:
                    logger.info(f"转录进行中... {i}/{max_retries}")
                time.sleep(1)

            if not task_resp or task_resp["state"] != 4:
                error_msg = f"B站ASR任务未能完成状态: {task_resp.get('state') if task_resp else 'Unknown'}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # Parse the result; the "result" field is itself a JSON string.
            logger.info("转录成功,处理结果...")
            result_json = json.loads(task_resp["result"])

            segments = []
            texts = []
            for u in result_json.get("utterances", []):
                text = u.get("transcript", "").strip()
                texts.append(text)
                segments.append(TranscriptSegment(
                    # Timestamps are divided by 1000 to convert milliseconds to seconds.
                    start=float(u.get("start_time", 0)) / 1000.0,
                    end=float(u.get("end_time", 0)) / 1000.0,
                    text=text
                ))

            result = TranscriptResult(
                language=result_json.get("language", "zh"),
                full_text=" ".join(texts).strip(),
                segments=segments,
                raw=result_json
            )
            # The completion event is intentionally not fired from here.
            return result
        except Exception as e:
            logger.error(f"B站ASR处理失败: {str(e)}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Log completion and send ``transcription_finished`` with the file path."""
        logger.info(f"B站ASR转写完成: {video_path}")
        transcription_finished.send({
            "file_path": video_path,
        })

View File

@@ -0,0 +1,115 @@
import requests
import logging
import os
from typing import Union, List, Dict, Optional
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
logger = get_logger(__name__)
class KuaishouTranscriber(Transcriber):
    """Speech-recognition client for the Kuaishou subtitle-generation API."""

    API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"

    def __init__(self):
        """No per-instance state is required."""
        pass

    def _load_file(self, file_path: str) -> bytes:
        """Return the raw bytes of the file at ``file_path``."""
        with open(file_path, 'rb') as f:
            return f.read()

    def _submit(self, file_path: str) -> dict:
        """Upload the file to the API and return the decoded JSON response.

        Args:
            file_path: Path of the audio file to submit.

        Returns:
            The full response object; it contains a ``data`` key.

        Raises:
            requests.exceptions.RequestException: On network or HTTP errors.
            Exception: If the response has no ``data`` key or a non-zero
                ``code``. All failures are logged and re-raised.
        """
        try:
            file_binary = self._load_file(file_path)
            payload = {
                "typeId": "1"
            }
            # Use the file's base name as the uploaded file name
            file_name = os.path.basename(file_path)
            files = [('file', (file_name, file_binary, 'audio/mpeg'))]
            logger.info(f"开始向快手API提交请求文件: {file_name}")
            response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
            response.raise_for_status()  # surface HTTP errors
            result = response.json()
            # Diagnostic dump goes through the logger rather than stdout.
            logger.debug(f"result {result}")
            # Check whether the API reported an error
            if "data" not in result or result.get("code", 0) != 0:
                error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
                logger.error(error_msg)
                raise Exception(error_msg)
            return result
        except requests.exceptions.RequestException as e:
            error_msg = f"快手ASR请求网络错误: {str(e)}"
            logger.error(error_msg)
            raise
        except Exception as e:
            error_msg = f"快手ASR请求处理错误: {str(e)}"
            logger.error(error_msg)
            raise

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Transcribe the audio file at ``file_path``.

        Returns:
            A ``TranscriptResult`` built from the ``data.text`` entries of the
            API response; the language is always reported as ``"zh"``.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            logger.info(f"开始处理文件: {file_path}")

            # Submit the request and fetch the result
            logger.info("向快手API提交识别请求...")
            result_data = self._submit(file_path)
            logger.info("请求成功,处理结果...")

            # "data" or "text" may be present but null; treat both as "no segments"
            # instead of failing with AttributeError/TypeError.
            data = result_data.get('data') or {}
            entries = data.get('text') or []

            segments = []
            texts = []
            for u in entries:
                text = u.get('text', '').strip()
                texts.append(text)
                segments.append(TranscriptSegment(
                    start=float(u.get('start_time', 0)),
                    end=float(u.get('end_time', 0)),
                    text=text
                ))

            result = TranscriptResult(
                language="zh",  # the response is not inspected for a language; default to Chinese
                full_text=" ".join(texts).strip(),
                segments=segments,
                raw=result_data
            )
            # The completion event is intentionally not fired from here.
            return result
        except Exception as e:
            logger.error(f"快手ASR处理失败: {str(e)}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Log completion and send ``transcription_finished`` with the file path."""
        logger.info(f"快手ASR转写完成: {video_path}")
        transcription_finished.send({
            "file_path": video_path,
        })

View File

@@ -31,15 +31,15 @@ _transcribers = {
def get_whisper_transcriber(model_size="base", device="cuda"):
    """Return the shared Whisper transcriber, creating it on first use.

    Args:
        model_size: Model size passed to ``WhisperTranscriber``.
        device: Device passed to ``WhisperTranscriber``.

    Returns:
        The cached ``WhisperTranscriber`` instance.

    Raises:
        Exception: Whatever ``WhisperTranscriber`` raises during creation;
            the failure is logged and re-raised.
    """
    # Check, store and return under one key. Checking 'fast-whisper' while
    # storing under 'whisper' meant the cache never hit and a new model was
    # built on every call. ``.get`` also tolerates the key being absent.
    if _transcribers.get('whisper') is None:
        logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
        try:
            _transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
            logger.info('Whisper 转录器创建成功')
        except Exception as e:
            logger.error(f"Whisper 转录器创建失败: {e}")
            raise
    return _transcribers['whisper']
def get_bcut_transcriber():
"""获取 Bcut 转录器实例"""

View File

@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.env_checker import is_cuda_available, is_torch_installed
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
from pathlib import Path
import os
from tqdm import tqdm
from huggingface_hub import snapshot_download
'''
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo).
'''
logger=get_logger(__name__)
class WhisperTranscriber(Transcriber):
# TODO:修改为可配置
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
model_path = get_model_dir("whisper")
model_dir = get_model_dir("whisper")
model_path = os.path.join(model_dir, f"whisper-{model_size}")
if not Path(model_path).exists():
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
repo_id = f"guillaumekln/faster-whisper-{model_size}"
snapshot_download(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
self.model = WhisperModel(
model_size,
device=self.device,
# compute_type="int8", # 或 "float16"
compute_type=self.compute_type,
cpu_threads=cpu_threads,
download_root=model_path
download_root=model_dir
)
@staticmethod
def is_torch_installed() -> bool:
try:
@@ -88,7 +103,7 @@ class WhisperTranscriber(Transcriber):
segments=segments,
raw=info
)
self.on_finish(file_path, result)
# self.on_finish(file_path, result)
return result
except Exception as e:
print(f"转写失败:{e}")