feat: 新增模型配置页面和相关功能

- 新增模型配置页面组件和路由
- 实现模型配置表单和相关逻辑
- 添加全局配置入口和功能
- 优化首页布局和样式
- 新增 404 页面组件
- 更新部分组件样式和结构
This commit is contained in:
Jefferyhcool
2025-04-22 17:01:02 +08:00
parent 2aad103a77
commit bb974b0b89
95 changed files with 7723 additions and 1697 deletions

View File

@@ -1,8 +1,9 @@
from fastapi import FastAPI
from .routers import note
from .routers import note, provider
def create_app() -> FastAPI:
    """Build the BiliNote FastAPI application.

    Mounts the ``note`` and ``provider`` routers under the ``/api`` prefix
    and returns the configured application instance.
    """
    application = FastAPI(title="BiliNote")
    # Registration order is kept: note first, then provider.
    for module in (note, provider):
        application.include_router(module.router, prefix="/api")
    return application

View File

@@ -0,0 +1,131 @@
from app.db.sqlite_client import get_connection
from app.utils.logger import get_logger
logger = get_logger(__name__)
def init_provider_table():
    """Create the ``providers`` table if it does not exist yet.

    Columns: ``id`` (autoincrement primary key), ``name``, ``logo``, ``type``,
    ``api_key``, ``base_url`` (all required text) and ``created_at`` (defaults
    to the current timestamp).

    Returns:
        None. A failed commit is logged instead of raised; an error raised
        while executing the statement still propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS providers (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                logo TEXT NOT NULL,
                type TEXT NOT NULL, -- ✅ 新增字段
                api_key TEXT NOT NULL,
                base_url TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        try:
            conn.commit()
            logger.info("provider table created successfully.")
        except Exception as e:
            logger.error(f"Failed to create provider table: {e}")
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
def insert_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
    """Insert one provider row and return its generated id.

    Args:
        name: Provider display name.
        api_key: API key stored for the provider.
        base_url: Base URL of the provider's API.
        logo: Logo identifier or URL.
        type_: Provider type, stored in the ``type`` column.

    Returns:
        The ``lastrowid`` of the inserted row, or ``None`` when no connection
        is available or the commit fails. An error raised by the INSERT
        itself propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO providers (name, api_key, base_url, logo, type)
            VALUES (?, ?, ?, ?, ?)
        """, (name, api_key, base_url, logo, type_))
        try:
            conn.commit()
        except Exception as e:
            logger.error(f"Failed to insert provider: {e}")
            return None
        cursor_id = cursor.lastrowid
        logger.info(f"Provider inserted successfully. name: {name}, type: {type_}")
        return cursor_id
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
def get_provider_by_name(name: str):
    """Fetch the first provider row whose ``name`` equals ``name``.

    Args:
        name: Provider name to look up (exact match).

    Returns:
        The row as returned by ``cursor.fetchone()`` (all columns of
        ``providers``), or ``None`` when nothing matches, no connection is
        available, or fetching fails. An error raised by the SELECT itself
        propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM providers WHERE name = ?", (name,))
        try:
            row = cursor.fetchone()
        except Exception as e:
            logger.error(f"Failed to get provider by name: {e}")
            return None
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
    if row is None:
        logger.info(f"Provider not found: {name}")
        return None
    # Log identifying columns only: the full row contains the API key.
    logger.info(f"Provider found: id={row[0]}, name={row[1]}")
    return row
def get_provider_by_id(id: int):
    """Fetch the provider row with primary key ``id``.

    Args:
        id: Value of the ``id`` column to look up.

    Returns:
        The row as returned by ``cursor.fetchone()`` (all columns of
        ``providers``), or ``None`` when nothing matches, no connection is
        available, or fetching fails. An error raised by the SELECT itself
        propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
        try:
            row = cursor.fetchone()
        except Exception as e:
            logger.error(f"Failed to get provider by id: {e}")
            return None
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
    if row is None:
        logger.info(f"Provider not found: {id}")
        return None
    # Log identifying columns only: the full row contains the API key.
    logger.info(f"Provider found: id={row[0]}, name={row[1]}")
    return row
def get_all_providers():
    """Fetch every row of the ``providers`` table.

    Returns:
        The list produced by ``cursor.fetchall()`` (empty when the table has
        no rows), or ``None`` when no connection is available or fetching
        fails. An error raised by the SELECT itself propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM providers")
        try:
            rows = cursor.fetchall()
        except Exception as e:
            logger.error(f"Failed to get all providers: {e}")
            return None
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
    # fetchall() yields an empty list rather than None, so test for emptiness.
    if not rows:
        logger.info("No providers found")
    else:
        # Log the count only: the rows themselves contain API keys.
        logger.info(f"Providers found: {len(rows)}")
    return rows
def update_provider(id: int, name: str, api_key: str, base_url: str, logo: str, type_: str):
    """Overwrite all editable columns of the provider row with primary key ``id``.

    Args:
        id: Primary key of the row to update.
        name: New provider name.
        api_key: New API key.
        base_url: New base URL.
        logo: New logo value.
        type_: New provider type, stored in the ``type`` column.

    Returns:
        None. A failed commit is logged instead of raised; an error raised by
        the UPDATE itself propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("""
            UPDATE providers
            SET name = ?, api_key = ?, base_url = ?, logo = ?, type = ?
            WHERE id = ?
        """, (name, api_key, base_url, logo, type_, id))
        try:
            conn.commit()
        except Exception as e:
            logger.error(f"Failed to update provider: {e}")
            return
        if cursor.rowcount == 0:
            # No row has this id: do not report a success that did not happen.
            logger.warning(f"No provider updated, id not found: {id}")
        else:
            logger.info(f"Provider updated successfully. id: {id}, type: {type_}")
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()
def delete_provider(id: int):
    """Delete the provider row with primary key ``id``.

    Args:
        id: Primary key of the row to delete.

    Returns:
        None. A failed commit is logged instead of raised; an error raised by
        the DELETE itself propagates to the caller.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return
    try:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM providers WHERE id = ?", (id,))
        try:
            conn.commit()
        except Exception as e:
            logger.error(f"Failed to delete provider: {e}")
            return
        if cursor.rowcount == 0:
            # No row has this id: do not report a success that did not happen.
            logger.warning(f"No provider deleted, id not found: {id}")
        else:
            logger.info(f"Provider deleted successfully. id: {id}")
    finally:
        # Release the connection on every path, including when execute() raises.
        conn.close()

View File

@@ -1,4 +1,4 @@
import sqlite3
def get_connection():
return sqlite3.connect("note_tasks.db")
return sqlite3.connect("bili_note.db")

View File

@@ -31,6 +31,13 @@ class BilibiliDownloader(Downloader, ABC):
ydl_opts = {
'format': 'bestaudio[ext=m4a]/bestaudio/best',
'outtmpl': output_path,
'postprocessors': [
{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'mp3',
'preferredquality': '64',
}
],
'noplaylist': True,
'quiet': False,
}
@@ -41,7 +48,7 @@ class BilibiliDownloader(Downloader, ABC):
title = info.get("title")
duration = info.get("duration", 0)
cover_url = info.get("thumbnail")
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
return AudioDownloadResult(
file_path=audio_path,
@@ -69,7 +76,7 @@ class BilibiliDownloader(Downloader, ABC):
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
ydl_opts = {
'format': 'bv*+ba/bestvideo+bestaudio/best',
'format': 'bv*[ext=mp4]/bestvideo+bestaudio/best',
'outtmpl': output_path,
'noplaylist': True,
'quiet': False,

View File

@@ -27,6 +27,13 @@ class DouyinDownloader(Downloader, ABC):
ydl_opts = {
'format': 'bestaudio[ext=m4a]/bestaudio/best',
'outtmpl': output_path,
'postprocessors': [
{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'mp3',
'preferredquality': '64',
}
],
'noplaylist': True,
'quiet': False,
}
@@ -37,7 +44,7 @@ class DouyinDownloader(Downloader, ABC):
title = info.get("title")
duration = info.get("duration", 0)
cover_url = info.get("thumbnail")
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
return AudioDownloadResult(
file_path=audio_path,

View File

@@ -42,7 +42,7 @@ class YoutubeDownloader(Downloader, ABC):
title = info.get("title")
duration = info.get("duration", 0)
cover_url = info.get("thumbnail")
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
return AudioDownloadResult(
file_path=audio_path,

View File

@@ -10,4 +10,8 @@ class GPT(ABC):
:param source:
:return:
'''
pass
def create_messages(self, segments:list,**kwargs)->list:
    """Build the message list sent to the model from transcript segments.

    Placeholder in the base class: the body is empty, so it returns ``None``
    unless a subclass overrides it.

    :param segments: transcript segments to turn into prompt text
    :param kwargs: extra values for the prompt (subclasses in this project
        read ``title`` and ``tags``)
    :return: list of chat messages (in overriding implementations)
    """
    pass
def list_models(self):
    """List the models available from the backing provider.

    Placeholder in the base class: the body is empty, so it returns ``None``
    unless a subclass overrides it.
    """
    pass

View File

@@ -0,0 +1,13 @@
from openai import OpenAI
from app.gpt.base import GPT
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.gpt.universal_gpt import UniversalGPT
from app.models.model_config import ModelConfig
class GPTFactory:
    """Factory that builds :class:`GPT` instances from a model configuration."""

    @staticmethod
    def from_config(config: ModelConfig) -> GPT:
        """Create a ``UniversalGPT`` bound to the endpoint described by ``config``.

        Args:
            config: Supplies ``api_key``, ``base_url`` and ``model_name``.

        Returns:
            A ``UniversalGPT`` using the provider's underlying client and the
            configured model name.
        """
        provider = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url)
        # ``get_client`` is declared as a @property, so it is read as an
        # attribute; calling it would invoke the client object itself and fail.
        client = provider.get_client
        return UniversalGPT(client=client, model=config.model_name)

View File

@@ -2,6 +2,7 @@ from typing import List
from app.gpt.base import GPT
from openai import OpenAI
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.gpt.utils import fix_markdown
from app.models.gpt_model import GPTSource
from app.models.transcriber_model import TranscriptSegment
@@ -15,7 +16,7 @@ class OpenaiGPT(GPT):
self.base_url = getenv("OPENAI_API_BASE_URL")
self.model=getenv('OPENAI_MODEL')
print(self.model)
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url)
self.screenshot = False
self.link=False
@@ -49,17 +50,20 @@ class OpenaiGPT(GPT):
print(content)
return [{"role": "user", "content": content + AI_SUM}]
def list_models(self):
return self.client.list_models()
def summarize(self, source: GPTSource) -> str:
self.screenshot = source.screenshot
self.link = source.link
source.segment = self.ensure_segments_type(source.segment)
messages = self.create_messages(source.segment, source.title,source.tags)
response = self.client.chat.completions.create(
response = self.client.chat(
model=self.model,
messages=messages,
temperature=0.7
)
return response.choices[0].message.content.strip()
if __name__ == '__main__':
gpt = OpenaiGPT()
print(gpt.list_models())

View File

@@ -0,0 +1,22 @@
from typing import Optional, Union
from openai import OpenAI
class OpenAICompatibleProvider:
    """Thin wrapper around an ``OpenAI`` SDK client for OpenAI-compatible endpoints.

    Besides exposing the raw client, it offers ``list_models`` and ``chat``
    so that code holding the wrapper (rather than the raw client) can list
    models and request chat completions.
    """

    def __init__(self, api_key: str, base_url: str, model: Union[str, None] = None):
        """Create the underlying client.

        Args:
            api_key: API key passed to the ``OpenAI`` client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Optional default model name used by :meth:`chat`.
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    @property
    def get_client(self):
        """The underlying ``OpenAI`` client (a property: read it, do not call it)."""
        return self.client

    def list_models(self):
        """Return the result of the endpoint's model listing (``client.models.list()``)."""
        return self.client.models.list()

    def chat(self, **kwargs):
        """Request a chat completion via ``client.chat.completions.create``.

        Keyword arguments are forwarded unchanged. When the wrapper was built
        with a default ``model`` and the caller passes none, the default is used.
        """
        if self.model is not None:
            kwargs.setdefault("model", self.model)
        return self.client.chat.completions.create(**kwargs)

    @staticmethod
    def test_connection(api_key: str, base_url: str) -> bool:
        """Return ``True`` if listing models with these credentials succeeds.

        Any exception is printed and reported as ``False``.
        """
        try:
            client = OpenAI(api_key=api_key, base_url=base_url)
            client.models.list()
            return True
        except Exception as e:
            print(f"Error connecting to OpenAI API: {e}")
            return False

View File

@@ -2,6 +2,7 @@ from typing import List
from app.gpt.base import GPT
from openai import OpenAI
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.gpt.utils import fix_markdown
from app.models.gpt_model import GPTSource
from app.models.transcriber_model import TranscriptSegment
@@ -15,7 +16,7 @@ class QwenGPT(GPT):
self.base_url = getenv("QWEN_API_BASE_URL")
self.model=getenv('QWEN_MODEL')
print(self.model)
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url)
self.screenshot = False
def _format_time(self, seconds: float) -> str:
@@ -44,7 +45,8 @@ class QwenGPT(GPT):
content += SCREENSHOT
print(content)
return [{"role": "user", "content": content + AI_SUM}]
def list_models(self):
return self.client.list_models()
def summarize(self, source: GPTSource) -> str:
self.screenshot = source.screenshot
source.segment = self.ensure_segments_type(source.segment)
@@ -56,4 +58,6 @@ class QwenGPT(GPT):
)
return response.choices[0].message.content.strip()
if __name__ == '__main__':
gpt = QwenGPT()
print(gpt.list_models())

17
backend/app/gpt/test.py Normal file
View File

@@ -0,0 +1,17 @@
from app.models.model_config import ModelConfig
if __name__ == '__main__':
    from app.gpt.gpt_factory import GPTFactory

    # Build the model config. ModelConfig declares no ``id`` field, so none
    # is passed (passing one raises TypeError).
    config = ModelConfig(
        api_key='',
        base_url='',
        model_name="gpt-4o",
        provider='openai',
        name='gpt-4o',
    )
    # Build the GPT instance from the config.
    gpt = GPTFactory.from_config(config)

View File

@@ -0,0 +1,62 @@
from app.gpt.base import GPT
from app.models.gpt_model import GPTSource
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
from app.gpt.utils import fix_markdown
from app.models.transcriber_model import TranscriptSegment
from datetime import timedelta
from typing import List
class UniversalGPT(GPT):
    """GPT implementation that works with any OpenAI-SDK-compatible client.

    The client is injected, so one class serves every provider that exposes
    ``client.chat.completions.create`` and ``client.models.list``.
    """

    def __init__(self, client, model: str, temperature: float = 0.7):
        """Store the client and generation settings.

        Args:
            client: OpenAI-SDK-compatible client object.
            model: Model name sent with each completion request.
            temperature: Sampling temperature sent with each request.
        """
        self.client = client
        self.model = model
        self.temperature = temperature
        self.screenshot = False
        self.link = False

    def _format_time(self, seconds: float) -> str:
        """Format a second offset as text by slicing ``str(timedelta)``.

        NOTE(review): dropping the first two characters removes the hour
        component for offsets of one hour or more; confirm what downstream
        timestamp parsing expects before changing the format.
        """
        return str(timedelta(seconds=int(seconds)))[2:]

    def _build_segment_text(self, segments: List[TranscriptSegment]) -> str:
        """Render segments as ``"<time> - <text>"`` lines joined by newlines."""
        return "\n".join(
            f"{self._format_time(seg.start)} - {seg.text.strip()}"
            for seg in segments
        )

    def ensure_segments_type(self, segments) -> List[TranscriptSegment]:
        """Convert dict items to ``TranscriptSegment``; other items pass through."""
        return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]

    def create_messages(self, segments: List[TranscriptSegment], **kwargs):
        """Build the single-user-message prompt for the summary request.

        Args:
            segments: Transcript segments rendered into the prompt.
            **kwargs: ``title`` and ``tags`` fill the prompt template.

        Returns:
            A one-element list holding the user message dict.
        """
        content = BASE_PROMPT.format(
            video_title=kwargs.get('title'),
            segment_text=self._build_segment_text(segments),
            tags=kwargs.get('tags')
        )
        if self.screenshot:
            print(":需要截图")
            content += SCREENSHOT
        if self.link:
            print(":需要链接")
            content += LINK
        print(content)
        return [{"role": "user", "content": content + AI_SUM}]

    def list_models(self):
        """Return the client's model listing (``client.models.list()``)."""
        return self.client.models.list()

    def summarize(self, source: GPTSource) -> str:
        """Generate the summary text for ``source``.

        Reads ``screenshot``, ``link``, ``segment``, ``title`` and ``tags``
        from ``source``; ``source.segment`` is replaced by its normalised form.

        Returns:
            The stripped content of the first completion choice (an empty
            string when the model returns no content).
        """
        self.screenshot = source.screenshot
        self.link = source.link
        source.segment = self.ensure_segments_type(source.segment)

        # create_messages accepts title/tags only as keyword arguments.
        messages = self.create_messages(
            source.segment, title=source.title, tags=source.tags
        )
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        content = response.choices[0].message.content
        return (content or "").strip()
if __name__ == '__main__':
print('s')

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class ModelConfig:
    """
    Call parameters for one model of a provider, intended to be read from the
    database and used to build a GPT client instance dynamically.
    """
    name: str  # Display name, e.g. "GPT-4 Turbo" (shown in the frontend)
    provider: str  # Model provider, e.g. "openai", "qwen", "deepseek"
    api_key: str  # API key used when calling this model
    base_url: str  # API endpoint of the model (OpenAI-SDK compatible)
    model_name: str  # Model name actually sent in requests, e.g. "gpt-4-turbo"
    created_at: Optional[datetime] = None  # Optional: creation time (generated by SQLite)

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class ProviderModel:
    """
    Stored connection parameters of one model provider, intended to be read
    from the database and used to build a GPT client instance dynamically.

    NOTE(review): the ``providers`` table created in this commit uses an
    INTEGER autoincrement ``id`` and also has a ``type`` column; this class
    declares ``id: str`` and no ``type`` -- confirm which is intended.
    """
    id: str  # Unique ID (UUID recommended)
    logo: str  # Icon URL
    name: str  # Display name, e.g. "GPT-4 Turbo" (shown in the frontend)
    api_key: str  # API key used for calls
    base_url: str  # API endpoint (OpenAI-SDK compatible)
    created_at: Optional[datetime] = None  # Optional: creation time (generated by SQLite)

View File

@@ -0,0 +1,82 @@
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.utils.response import ResponseWrapper as R
from app.services.provider import ProviderService
router = APIRouter()
# Request body for creating a provider; includes the `type` field.
class ProviderRequest(BaseModel):
    """Payload of ``POST /add_provider``: every field is required."""
    name: str
    api_key: str
    base_url: str
    logo: str
    type: str
class ProviderUpdateRequest(BaseModel):
    """Payload of ``POST /update_provider/``.

    ``id`` selects the provider; every other field is optional and defaults
    to ``None`` when the client omits it.
    """
    id: int
    name: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    logo: Optional[str] = None
    type: Optional[str] = None
@router.post("/add_provider")
def add_provider(data: ProviderRequest):
    """Create a provider from the request body.

    Returns a success response when the service call completes, or an error
    response carrying the exception text when it raises.
    """
    try:
        ProviderService.add_provider(
            name=data.name,
            api_key=data.api_key,
            base_url=data.base_url,
            logo=data.logo,
            type_=data.type
        )
        return R.success(msg='添加模型供应商成功')
    except Exception as e:
        # Pass the message text: the exception object itself is not serialisable.
        return R.error(msg=str(e))
@router.get("/get_all_providers")
def get_all_providers():
    """Return every stored provider.

    Responds with the service result as ``data``, or with an error response
    carrying the exception text when the service raises.
    """
    try:
        res = ProviderService.get_all_providers()
        return R.success(data=res)
    except Exception as e:
        # Pass the message text: the exception object itself is not serialisable.
        return R.error(msg=str(e))
@router.get("/get_provider_by_id/{id}")
def get_provider_by_id(id: int):
    """Return the provider with the given ``id`` path parameter.

    Responds with the service result as ``data``, or with an error response
    carrying the exception text when the service raises.
    """
    try:
        res = ProviderService.get_provider_by_id(id)
        return R.success(data=res)
    except Exception as e:
        # Pass the message text: the exception object itself is not serialisable.
        return R.error(msg=str(e))
@router.get("/get_provider_by_name/{name}")
def get_provider_by_name(name: str):
    """Return the provider with the given ``name`` path parameter.

    Responds with the service result as ``data``, or with an error response
    carrying the exception text when the service raises.
    """
    try:
        res = ProviderService.get_provider_by_name(name)
        return R.success(data=res)
    except Exception as e:
        # Pass the message text: the exception object itself is not serialisable.
        return R.error(msg=str(e))
@router.post("/update_provider/")
def update_provider(data: ProviderUpdateRequest):
    """Update the fields of a provider that the request actually supplies.

    Fields left as ``None`` keep their stored value; they are no longer
    overwritten with empty strings. Responds with an error when no field is
    supplied, when the provider does not exist, or when the service raises.
    """
    try:
        if all(
            field is None
            for field in [data.name, data.api_key, data.base_url, data.logo, data.type]
        ):
            return R.error(msg='请至少填写一个参数')

        existing = ProviderService.get_provider_by_id(data.id)
        if existing is None:
            return R.error(msg='模型供应商不存在')

        # Stored row layout: (id, name, logo, type, api_key, base_url, created_at).
        cur_name, cur_logo, cur_type, cur_api_key, cur_base_url = existing[1:6]

        ProviderService.update_provider(
            id=data.id,
            name=cur_name if data.name is None else data.name,
            api_key=cur_api_key if data.api_key is None else data.api_key,
            base_url=cur_base_url if data.base_url is None else data.base_url,
            logo=cur_logo if data.logo is None else data.logo,
            type_=cur_type if data.type is None else data.type
        )
        return R.success(msg='更新模型供应商成功')
    except Exception as e:
        # Pass the message text: the exception object itself is not serialisable.
        return R.error(msg=str(e))

View File

@@ -0,0 +1,23 @@
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
class ModelService:
    """Service for querying which models a stored provider offers."""

    @staticmethod
    def get_model_list(provider_id: int):
        """List the models available from the provider with ``provider_id``.

        Args:
            provider_id: Primary key of the provider row.

        Returns:
            ``[]`` when the provider does not exist; otherwise the result of
            ``list_models()`` on a GPT instance built for that provider.
        """
        provider = ProviderService.get_provider_by_id(provider_id)
        if not provider:
            return []
        # The service returns the raw database row, so columns are read by
        # position: (id, name, logo, type, api_key, base_url, created_at).
        provider_name = provider[1]
        config = ModelConfig(
            api_key=provider[4],
            base_url=provider[5],
            provider=provider_name,
            model_name='',
            name=provider_name,
        )
        gpt = GPTFactory.from_config(config)
        return gpt.list_models()
if __name__ == '__main__':
print(ModelService.get_model_list(1))

View File

@@ -53,21 +53,18 @@ class NoteGenerator:
def get_gpt(self) -> GPT:
    """Return the GPT implementation selected by ``self.provider``.

    ``self.provider`` is lower-cased in place first, so every comparison
    below must use a lower-case literal.

    Raises:
        ValueError: if the provider is not one of openai / deepseek / qwen.
    """
    self.provider = self.provider.lower()
    if self.provider == 'openai':
        logger.info("使用OpenAI")
        return OpenaiGPT()
    elif self.provider == 'deepseek':
        # Must be all lower-case: 'deepSeek' can never equal a lower-cased value.
        logger.info("使用DeepSeek")
        return DeepSeekGPT()
    elif self.provider == 'qwen':
        logger.info("使用Qwen")
        return QwenGPT()
    else:
        logger.warning("不支持的AI提供商")
        raise ValueError(f"不支持的AI提供商{self.provider}")
def get_downloader(self, platform: str) -> Downloader:
if platform == "bilibili":
@@ -162,9 +159,9 @@ class NoteGenerator:
# 1. 选择下载器
downloader = self.get_downloader(platform)
gpt = self.get_gpt()
logger.info(f'使用{downloader.__class__.__name__}下载器')
logger.info(f'使用{gpt.__class__.__name__}GPT')
logger.info(f'视频地址:{video_url}')
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
f'使用{gpt.__class__.__name__}GPT\n'
f'视频地址:{video_url}')
if screenshot:
video_path = downloader.download_video(video_url)

View File

@@ -0,0 +1,54 @@
from app.db.provider_dao import (
insert_provider,
init_provider_table,
get_all_providers,
get_provider_by_name,
get_provider_by_id,
update_provider,
delete_provider,
)
class ProviderService:
    """Service layer over the provider DAO functions.

    The static methods share their names with the imported DAO functions;
    inside a method body the bare name resolves to the module-level DAO
    function, not to the method.
    """

    @staticmethod
    def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
        """Insert a provider and return whatever the DAO insert returns."""
        return insert_provider(name, api_key, base_url, logo, type_)

    @staticmethod
    def get_all_providers():
        """Return every provider as a list of dicts.

        Each dict has the keys ``id``, ``name``, ``logo``, ``type``,
        ``api_key`` and ``base_url``, taken from row positions 0-5. The list
        is empty when there are no rows or the DAO returns ``None``.
        """
        # The DAO returns None when the query could not be run.
        rows = get_all_providers() or []
        return [
            {
                "id": row[0],
                "name": row[1],
                "logo": row[2],
                "type": row[3],
                "api_key": row[4],
                "base_url": row[5],
            }
            for row in rows
        ]

    @staticmethod
    def get_provider_by_name(name: str):
        """Return the raw row for the provider named ``name`` (or ``None``)."""
        return get_provider_by_name(name)

    @staticmethod
    def get_provider_by_id(id: int):
        """Return the raw row for the provider with this ``id`` (or ``None``)."""
        return get_provider_by_id(id)

    @staticmethod
    def update_provider(
        id: int,
        name: str,
        api_key: str,
        base_url: str,
        logo: str,
        type_: str
    ):
        """Overwrite all editable columns of the provider with this ``id``."""
        return update_provider(id, name, api_key, base_url, logo, type_)

    @staticmethod
    def delete_provider(id: int):
        """Delete the provider with this ``id``."""
        return delete_provider(id)

View File

@@ -1,250 +0,0 @@
import json
import logging
import time
from typing import Optional, List, Dict, Union
import requests
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
__version__ = "0.0.3"
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
# 申请上传
API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
# 提交上传
API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
# 创建任务
API_CREATE_TASK = API_BASE_URL + "/task"
# 查询结果
API_QUERY_RESULT = API_BASE_URL + "/task/result"
logger = get_logger(__name__)
class BcutTranscriber(Transcriber):
"""必剪 语音识别接口"""
headers = {
'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
'Content-Type': 'application/json'
}
def __init__(self):
self.session = requests.Session()
self.task_id = None
self.__etags = []
self.__in_boss_key: Optional[str] = None
self.__resource_id: Optional[str] = None
self.__upload_id: Optional[str] = None
self.__upload_urls: List[str] = []
self.__per_size: Optional[int] = None
self.__clips: Optional[int] = None
self.__etags: List[str] = []
self.__download_url: Optional[str] = None
self.task_id: Optional[str] = None
def _load_file(self, file_path: str) -> bytes:
"""读取文件内容"""
with open(file_path, 'rb') as f:
return f.read()
def _upload(self, file_path: str) -> None:
"""申请上传"""
file_binary = self._load_file(file_path)
if not file_binary:
raise ValueError("无法读取文件数据")
payload = json.dumps({
"type": 2,
"name": "audio.mp3",
"size": len(file_binary),
"ResourceFileType": "mp3",
"model_id": "8",
})
resp = self.session.post(
API_REQ_UPLOAD,
data=payload,
headers=self.headers
)
resp.raise_for_status()
resp = resp.json()
resp_data = resp["data"]
self.__in_boss_key = resp_data["in_boss_key"]
self.__resource_id = resp_data["resource_id"]
self.__upload_id = resp_data["upload_id"]
self.__upload_urls = resp_data["upload_urls"]
self.__per_size = resp_data["per_size"]
self.__clips = len(resp_data["upload_urls"])
logger.info(
f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
)
self.__upload_part(file_binary)
self.__commit_upload()
def __upload_part(self, file_binary: bytes) -> None:
"""上传音频数据"""
for clip in range(self.__clips):
start_range = clip * self.__per_size
end_range = min((clip + 1) * self.__per_size, len(file_binary))
logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
resp = self.session.put(
self.__upload_urls[clip],
data=file_binary[start_range:end_range],
headers={'Content-Type': 'application/octet-stream'}
)
resp.raise_for_status()
etag = resp.headers.get("Etag", "").strip('"')
self.__etags.append(etag)
logger.info(f"分片{clip}上传成功: {etag}")
def __commit_upload(self) -> None:
"""提交上传数据"""
data = json.dumps({
"InBossKey": self.__in_boss_key,
"ResourceId": self.__resource_id,
"Etags": ",".join(self.__etags),
"UploadId": self.__upload_id,
"model_id": "8",
})
resp = self.session.post(
API_COMMIT_UPLOAD,
data=data,
headers=self.headers
)
resp.raise_for_status()
resp = resp.json()
if resp.get("code") != 0:
error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
logger.error(error_msg)
raise Exception(error_msg)
self.__download_url = resp["data"]["download_url"]
logger.info(f"提交成功,下载链接: {self.__download_url}")
def _create_task(self) -> str:
"""开始创建转换任务"""
resp = self.session.post(
API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers
)
resp.raise_for_status()
resp = resp.json()
if resp.get("code") != 0:
error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
logger.error(error_msg)
raise Exception(error_msg)
self.task_id = resp["data"]["task_id"]
logger.info(f"任务已创建: {self.task_id}")
return self.task_id
def _query_result(self) -> dict:
"""查询转换结果"""
resp = self.session.get(
API_QUERY_RESULT,
params={"model_id": 7, "task_id": self.task_id},
headers=self.headers
)
resp.raise_for_status()
resp = resp.json()
if resp.get("code") != 0:
error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
logger.error(error_msg)
raise Exception(error_msg)
return resp["data"]
@timeit
def transcript(self, file_path: str) -> TranscriptResult:
"""执行识别过程,符合 Transcriber 接口"""
try:
logger.info(f"开始处理文件: {file_path}")
# 上传文件
logger.info("正在上传文件...")
self._upload(file_path)
# 创建任务
logger.info("提交转录任务...")
self._create_task()
# 轮询检查任务状态
logger.info("等待转录结果...")
task_resp = None
max_retries = 500
for i in range(max_retries):
task_resp = self._query_result()
if task_resp["state"] == 4: # 完成状态
break
elif task_resp["state"] == 3: # 失败状态
error_msg = f"B站ASR任务失败状态码: {task_resp['state']}"
logger.error(error_msg)
raise Exception(error_msg)
# 每隔一段时间打印进度
if i % 10 == 0:
logger.info(f"转录进行中... {i}/{max_retries}")
time.sleep(1)
if not task_resp or task_resp["state"] != 4:
error_msg = f"B站ASR任务未能完成状态: {task_resp.get('state') if task_resp else 'Unknown'}"
logger.error(error_msg)
raise Exception(error_msg)
# 解析结果
logger.info("转录成功,处理结果...")
result_json = json.loads(task_resp["result"])
# 提取分段数据
segments = []
full_text = ""
for u in result_json.get("utterances", []):
text = u.get("transcript", "").strip()
# B站ASR返回的时间戳是毫秒需要转换为秒
start_time = float(u.get("start_time", 0)) / 1000.0
end_time = float(u.get("end_time", 0)) / 1000.0
full_text += text + " "
segments.append(TranscriptSegment(
start=start_time,
end=end_time,
text=text
))
# 创建结果对象
result = TranscriptResult(
language=result_json.get("language", "zh"),
full_text=full_text.strip(),
segments=segments,
raw=result_json
)
# 触发完成事件
self.on_finish(file_path, result)
return result
except Exception as e:
logger.error(f"B站ASR处理失败: {str(e)}")
raise
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
"""转录完成的回调"""
logger.info(f"B站ASR转写完成: {video_path}")
transcription_finished.send({
"file_path": video_path,
})

View File

@@ -1,115 +0,0 @@
import requests
import logging
import os
from typing import Union, List, Dict, Optional
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
logger = get_logger(__name__)
class KuaishouTranscriber(Transcriber):
"""快手语音识别实现"""
API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"
def __init__(self):
pass
def _load_file(self, file_path: str) -> bytes:
"""读取文件内容"""
with open(file_path, 'rb') as f:
return f.read()
def _submit(self, file_path: str) -> dict:
"""提交识别请求"""
try:
file_binary = self._load_file(file_path)
payload = {
"typeId": "1"
}
# 使用文件名作为上传文件名
file_name = os.path.basename(file_path)
files = [('file', (file_name, file_binary, 'audio/mpeg'))]
logger.info(f"开始向快手API提交请求文件: {file_name}")
response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
response.raise_for_status() # 检查HTTP错误
result = response.json()
# 检查快手API返回是否包含错误
if "data" not in result or result.get("code", 0) != 0:
error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
logger.error(error_msg)
raise Exception(error_msg)
return result
except requests.exceptions.RequestException as e:
error_msg = f"快手ASR请求网络错误: {str(e)}"
logger.error(error_msg)
raise
except Exception as e:
error_msg = f"快手ASR请求处理错误: {str(e)}"
logger.error(error_msg)
raise
@timeit
def transcript(self, file_path: str) -> TranscriptResult:
"""执行转录过程,符合 Transcriber 接口"""
try:
logger.info(f"开始处理文件: {file_path}")
# 提交请求并获取结果
logger.info("向快手API提交识别请求...")
result_data = self._submit(file_path)
logger.info("请求成功,处理结果...")
# 提取分段数据
segments = []
full_text = ""
# 解析快手API返回的文本段
texts = result_data.get('data', {}).get('text', [])
for u in texts:
text = u.get('text', '').strip()
start_time = float(u.get('start_time', 0))
end_time = float(u.get('end_time', 0))
full_text += text + " "
segments.append(TranscriptSegment(
start=start_time,
end=end_time,
text=text
))
# 创建结果对象
result = TranscriptResult(
language="zh", # 快手API可能不返回语言信息默认为中文
full_text=full_text.strip(),
segments=segments,
raw=result_data
)
# 触发完成事件
self.on_finish(file_path, result)
return result
except Exception as e:
logger.error(f"快手ASR处理失败: {str(e)}")
raise
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
"""转录完成的回调"""
logger.info(f"快手ASR转写完成: {video_path}")
transcription_finished.send({
"file_path": video_path,
})

View File

@@ -1,74 +1,19 @@
from app.transcriber.whisper import WhisperTranscriber
from app.transcriber.bcut import BcutTranscriber
from app.transcriber.kuaishou import KuaishouTranscriber
from app.utils.logger import get_logger
logger = get_logger(__name__)
logger.info('初始化转录服务提供器')
logger.info('实例化transcriber')
# TODO:后面需要加入逻辑选择
_transcriber = None
# 维护各种转录器的单例实例
_transcribers = {
'whisper': None,
'bcut': None,
'kuaishou': None
}
def get_transcriber(model_size="base", device="cuda"):
global _transcriber
def get_whisper_transcriber(model_size="base", device="cuda"):
"""获取 Whisper 转录器实例"""
if _transcribers['whisper'] is None:
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
if _transcriber is None:
logger.info('不存在 transcriber 开始实例化transcriber。')
try:
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
logger.info('Whisper 转录器创建成功')
_transcriber = WhisperTranscriber(model_size=model_size, device=device)
logger.info(f'实例化transcriber成功。参数{model_size}, {device} ')
except Exception as e:
logger.error(f"Whisper 转录器创建失败: {e}")
raise
return _transcribers['whisper']
def get_bcut_transcriber():
"""获取 Bcut 转录器实例"""
if _transcribers['bcut'] is None:
logger.info('创建 Bcut 转录器实例')
try:
_transcribers['bcut'] = BcutTranscriber()
logger.info('Bcut 转录器创建成功')
except Exception as e:
logger.error(f"Bcut 转录器创建失败: {e}")
raise
return _transcribers['bcut']
def get_kuaishou_transcriber():
"""获取快手转录器实例"""
if _transcribers['kuaishou'] is None:
logger.info('创建快手转录器实例')
try:
_transcribers['kuaishou'] = KuaishouTranscriber()
logger.info('快手转录器创建成功')
except Exception as e:
logger.error(f"快手转录器创建失败: {e}")
raise
return _transcribers['kuaishou']
def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda"):
"""
获取指定类型的转录器实例
参数:
transcriber_type: 转录器类型,支持 "whisper", "bcut", "kuaishou"
model_size: 模型大小whisper 特有参数
device: 设备类型whisper 特有参数
返回:
对应类型的转录器实例
"""
logger.info(f'获取转录器,类型: {transcriber_type}')
if transcriber_type == "whisper":
return get_whisper_transcriber(model_size, device)
elif transcriber_type == "bcut":
return get_bcut_transcriber()
elif transcriber_type == "kuaishou":
return get_kuaishou_transcriber()
else:
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
return get_whisper_transcriber(model_size, device)
logger.error(f"实例化transcriber失败请检查是否安装whisper。{e}")
return _transcriber

View File

@@ -4,19 +4,14 @@ from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.env_checker import is_cuda_available, is_torch_installed
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
from pathlib import Path
import os
from tqdm import tqdm
from huggingface_hub import snapshot_download
'''
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
'''
logger=get_logger(__name__)
class WhisperTranscriber(Transcriber):
# TODO:修改为可配置
@@ -36,25 +31,15 @@ class WhisperTranscriber(Transcriber):
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
model_dir = get_model_dir("whisper")
model_path = os.path.join(model_dir, f"whisper-{model_size}")
if not Path(model_path).exists():
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
repo_id = f"guillaumekln/faster-whisper-{model_size}"
snapshot_download(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
model_path = get_model_dir("whisper")
self.model = WhisperModel(
model_size,
device=self.device,
compute_type=self.compute_type,
# compute_type="int8", # 或 "float16"
cpu_threads=cpu_threads,
download_root=model_dir
download_root=model_path
)
@staticmethod
def is_torch_installed() -> bool:
try:

View File

@@ -3,6 +3,8 @@ import os
import uvicorn
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
from app import create_app
from app.db.video_task_dao import init_video_task_table
@@ -36,6 +38,7 @@ async def startup_event():
ensure_ffmpeg_or_raise()
get_transcriber()
init_video_task_table()
init_provider_table()
if __name__ == "__main__":
port = int(os.getenv("BACKEND_PORT", 8000))

Binary file not shown.