mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-19 11:49:30 +08:00
:feat 新增模型配置页面和相关功能
- 新增模型配置页面组件和路由 - 实现模型配置表单和相关逻辑- 添加全局配置入口和功能- 优化首页布局和样式- 新增 404 页面组件 - 更新部分组件样式和结构
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from fastapi import FastAPI
|
||||
from .routers import note
|
||||
from .routers import note, provider
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="BiliNote")
|
||||
app.include_router(note.router, prefix="/api")
|
||||
app.include_router(provider.router, prefix="/api")
|
||||
return app
|
||||
|
||||
131
backend/app/db/provider_dao.py
Normal file
131
backend/app/db/provider_dao.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from app.db.sqlite_client import get_connection
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def init_provider_table():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS providers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
logo TEXT NOT NULL,
|
||||
type TEXT NOT NULL, -- ✅ 新增字段
|
||||
api_key TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("provider table created successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create provider table: {e}")
|
||||
def insert_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO providers (name, api_key, base_url, logo, type)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (name, api_key, base_url, logo, type_))
|
||||
try:
|
||||
conn.commit()
|
||||
cursor_id = cursor.lastrowid
|
||||
conn.close()
|
||||
logger.info(f"Provider inserted successfully. name: {name}, type: {type_}")
|
||||
return cursor_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert provider: {e}")
|
||||
return None
|
||||
def get_provider_by_name(name: str):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE name = ?", (name,))
|
||||
try:
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {name}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row}")
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by name: {e}")
|
||||
|
||||
def get_provider_by_id(id: int):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
|
||||
try:
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {id}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row}")
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by id: {e}")
|
||||
|
||||
def get_all_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers")
|
||||
try:
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
if rows is None:
|
||||
logger.info("No providers found")
|
||||
return None
|
||||
logger.info(f"Providers found: {rows}")
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all providers: {e}")
|
||||
|
||||
def update_provider(id: int, name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE providers
|
||||
SET name = ?, api_key = ?, base_url = ?, logo = ?, type = ?
|
||||
WHERE id = ?
|
||||
""", (name, api_key, base_url, logo, type_, id))
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Provider updated successfully. id: {id}, type: {type_}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update provider: {e}")
|
||||
def delete_provider(id: int):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM providers WHERE id = ?", (id,))
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Provider deleted successfully. id: {id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete provider: {e}")
|
||||
@@ -1,4 +1,4 @@
|
||||
import sqlite3
|
||||
|
||||
def get_connection():
|
||||
return sqlite3.connect("note_tasks.db")
|
||||
return sqlite3.connect("bili_note.db")
|
||||
|
||||
@@ -31,6 +31,13 @@ class BilibiliDownloader(Downloader, ABC):
|
||||
ydl_opts = {
|
||||
'format': 'bestaudio[ext=m4a]/bestaudio/best',
|
||||
'outtmpl': output_path,
|
||||
'postprocessors': [
|
||||
{
|
||||
'key': 'FFmpegExtractAudio',
|
||||
'preferredcodec': 'mp3',
|
||||
'preferredquality': '64',
|
||||
}
|
||||
],
|
||||
'noplaylist': True,
|
||||
'quiet': False,
|
||||
}
|
||||
@@ -41,7 +48,7 @@ class BilibiliDownloader(Downloader, ABC):
|
||||
title = info.get("title")
|
||||
duration = info.get("duration", 0)
|
||||
cover_url = info.get("thumbnail")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
|
||||
|
||||
return AudioDownloadResult(
|
||||
file_path=audio_path,
|
||||
@@ -69,7 +76,7 @@ class BilibiliDownloader(Downloader, ABC):
|
||||
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
|
||||
|
||||
ydl_opts = {
|
||||
'format': 'bv*+ba/bestvideo+bestaudio/best',
|
||||
'format': 'bv*[ext=mp4]/bestvideo+bestaudio/best',
|
||||
'outtmpl': output_path,
|
||||
'noplaylist': True,
|
||||
'quiet': False,
|
||||
|
||||
@@ -27,6 +27,13 @@ class DouyinDownloader(Downloader, ABC):
|
||||
ydl_opts = {
|
||||
'format': 'bestaudio[ext=m4a]/bestaudio/best',
|
||||
'outtmpl': output_path,
|
||||
'postprocessors': [
|
||||
{
|
||||
'key': 'FFmpegExtractAudio',
|
||||
'preferredcodec': 'mp3',
|
||||
'preferredquality': '64',
|
||||
}
|
||||
],
|
||||
'noplaylist': True,
|
||||
'quiet': False,
|
||||
}
|
||||
@@ -37,7 +44,7 @@ class DouyinDownloader(Downloader, ABC):
|
||||
title = info.get("title")
|
||||
duration = info.get("duration", 0)
|
||||
cover_url = info.get("thumbnail")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
|
||||
|
||||
return AudioDownloadResult(
|
||||
file_path=audio_path,
|
||||
|
||||
@@ -42,7 +42,7 @@ class YoutubeDownloader(Downloader, ABC):
|
||||
title = info.get("title")
|
||||
duration = info.get("duration", 0)
|
||||
cover_url = info.get("thumbnail")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.m4a")
|
||||
audio_path = os.path.join(output_dir, f"{video_id}.mp3")
|
||||
|
||||
return AudioDownloadResult(
|
||||
file_path=audio_path,
|
||||
|
||||
@@ -10,4 +10,8 @@ class GPT(ABC):
|
||||
:param source:
|
||||
:return:
|
||||
'''
|
||||
pass
|
||||
def create_messages(self, segments:list,**kwargs)->list:
|
||||
pass
|
||||
def list_models(self):
|
||||
pass
|
||||
13
backend/app/gpt/gpt_factory.py
Normal file
13
backend/app/gpt/gpt_factory.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from openai import OpenAI
|
||||
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.gpt.universal_gpt import UniversalGPT
|
||||
from app.models.model_config import ModelConfig
|
||||
|
||||
|
||||
class GPTFactory:
|
||||
@staticmethod
|
||||
def from_config(config: ModelConfig) -> GPT:
|
||||
client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client()
|
||||
return UniversalGPT(client=client, model=config.model_name)
|
||||
@@ -2,6 +2,7 @@ from typing import List
|
||||
from app.gpt.base import GPT
|
||||
from openai import OpenAI
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.gpt.utils import fix_markdown
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.models.transcriber_model import TranscriptSegment
|
||||
@@ -15,7 +16,7 @@ class OpenaiGPT(GPT):
|
||||
self.base_url = getenv("OPENAI_API_BASE_URL")
|
||||
self.model=getenv('OPENAI_MODEL')
|
||||
print(self.model)
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url)
|
||||
self.screenshot = False
|
||||
self.link=False
|
||||
|
||||
@@ -49,17 +50,20 @@ class OpenaiGPT(GPT):
|
||||
|
||||
print(content)
|
||||
return [{"role": "user", "content": content + AI_SUM}]
|
||||
|
||||
def list_models(self):
|
||||
return self.client.list_models()
|
||||
def summarize(self, source: GPTSource) -> str:
|
||||
self.screenshot = source.screenshot
|
||||
self.link = source.link
|
||||
source.segment = self.ensure_segments_type(source.segment)
|
||||
messages = self.create_messages(source.segment, source.title,source.tags)
|
||||
response = self.client.chat.completions.create(
|
||||
response = self.client.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = OpenaiGPT()
|
||||
print(gpt.list_models())
|
||||
|
||||
22
backend/app/gpt/provider/OpenAI_compatible_provider.py
Normal file
22
backend/app/gpt/provider/OpenAI_compatible_provider.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class OpenAICompatibleProvider:
|
||||
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def test_connection(api_key: str, base_url: str) -> bool:
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
client.models.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error connecting to OpenAI API: {e}")
|
||||
return False
|
||||
@@ -2,6 +2,7 @@ from typing import List
|
||||
from app.gpt.base import GPT
|
||||
from openai import OpenAI
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.gpt.utils import fix_markdown
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.models.transcriber_model import TranscriptSegment
|
||||
@@ -15,7 +16,7 @@ class QwenGPT(GPT):
|
||||
self.base_url = getenv("QWEN_API_BASE_URL")
|
||||
self.model=getenv('QWEN_MODEL')
|
||||
print(self.model)
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url)
|
||||
self.screenshot = False
|
||||
|
||||
def _format_time(self, seconds: float) -> str:
|
||||
@@ -44,7 +45,8 @@ class QwenGPT(GPT):
|
||||
content += SCREENSHOT
|
||||
print(content)
|
||||
return [{"role": "user", "content": content + AI_SUM}]
|
||||
|
||||
def list_models(self):
|
||||
return self.client.list_models()
|
||||
def summarize(self, source: GPTSource) -> str:
|
||||
self.screenshot = source.screenshot
|
||||
source.segment = self.ensure_segments_type(source.segment)
|
||||
@@ -56,4 +58,6 @@ class QwenGPT(GPT):
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = QwenGPT()
|
||||
print(gpt.list_models())
|
||||
|
||||
17
backend/app/gpt/test.py
Normal file
17
backend/app/gpt/test.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from app.models.model_config import ModelConfig
|
||||
|
||||
if __name__ == '__main__':
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
# 构建模型config
|
||||
config=ModelConfig(
|
||||
id='asas',
|
||||
api_key='',
|
||||
base_url='',
|
||||
model_name="gpt-4o",
|
||||
provider='openai',
|
||||
name='gpt-4o'
|
||||
)
|
||||
# 构建GPT
|
||||
gpt=GPTFactory().from_config(config)
|
||||
|
||||
|
||||
62
backend/app/gpt/universal_gpt.py
Normal file
62
backend/app/gpt/universal_gpt.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from app.gpt.base import GPT
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
|
||||
from app.gpt.utils import fix_markdown
|
||||
from app.models.transcriber_model import TranscriptSegment
|
||||
from datetime import timedelta
|
||||
from typing import List
|
||||
|
||||
class UniversalGPT(GPT):
|
||||
def __init__(self, client, model: str, temperature: float = 0.7):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.screenshot = False
|
||||
self.screenshot = False
|
||||
self.link = False
|
||||
|
||||
def _format_time(self, seconds: float) -> str:
|
||||
return str(timedelta(seconds=int(seconds)))[2:]
|
||||
|
||||
def _build_segment_text(self, segments: List[TranscriptSegment]) -> str:
|
||||
return "\n".join(
|
||||
f"{self._format_time(seg.start)} - {seg.text.strip()}"
|
||||
for seg in segments
|
||||
)
|
||||
|
||||
def ensure_segments_type(self, segments) -> List[TranscriptSegment]:
|
||||
return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]
|
||||
|
||||
def create_messages(self, segments: List[TranscriptSegment],**kwargs):
|
||||
content = BASE_PROMPT.format(
|
||||
video_title=kwargs.get('title'),
|
||||
segment_text=self._build_segment_text(segments),
|
||||
tags=kwargs.get('tags')
|
||||
)
|
||||
if self.screenshot:
|
||||
print(":需要截图")
|
||||
content += SCREENSHOT
|
||||
if self.link:
|
||||
print(":需要链接")
|
||||
content += LINK
|
||||
|
||||
print(content)
|
||||
return [{"role": "user", "content": content + AI_SUM}]
|
||||
|
||||
def list_models(self):
|
||||
return self.client.list_models()
|
||||
def summarize(self, source: GPTSource) -> str:
|
||||
self.screenshot = source.screenshot
|
||||
self.link = source.link
|
||||
source.segment = self.ensure_segments_type(source.segment)
|
||||
messages = self.create_messages(source.segment, source.title,source.tags)
|
||||
response = self.client.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('s')
|
||||
|
||||
16
backend/app/models/model_config.py
Normal file
16
backend/app/models/model_config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""
|
||||
存储每个模型提供商的调用参数信息,用于从数据库读取并动态构建 GPT 调用实例。
|
||||
"""
|
||||
name: str # 展示名,如 "GPT-4 Turbo"(用于前端展示)
|
||||
provider: str # 模型提供商,如 "openai"、"qwen"、"deepseek"
|
||||
api_key: str # 调用该模型使用的 API Key
|
||||
base_url: str # 模型 API 接口地址(OpenAI SDK兼容)
|
||||
model_name: str # 实际请求用的模型名称,如 "gpt-4-turbo"
|
||||
created_at: Optional[datetime] = None # 可选:创建时间(从 SQLite 自动生成)
|
||||
16
backend/app/models/provide_model.py
Normal file
16
backend/app/models/provide_model.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderModel:
|
||||
"""
|
||||
存储每个模型提供商的调用参数信息,用于从数据库读取并动态构建 GPT 调用实例。
|
||||
"""
|
||||
id: str # 模型唯一 ID(推荐用 UUID)
|
||||
logo: str # 模型图标 URL
|
||||
name: str # 展示名,如 "GPT-4 Turbo"(用于前端展示)
|
||||
api_key: str # 调用该模型使用的 API Key
|
||||
base_url: str # 模型 API 接口地址(OpenAI SDK兼容)
|
||||
created_at: Optional[datetime] = None # 可选:创建时间(从 SQLite 自动生成)
|
||||
82
backend/app/routers/provider.py
Normal file
82
backend/app/routers/provider.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
from app.services.provider import ProviderService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ✅ 新增 type 字段
|
||||
class ProviderRequest(BaseModel):
|
||||
name: str
|
||||
api_key: str
|
||||
base_url: str
|
||||
logo: str
|
||||
type: str
|
||||
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
id: int
|
||||
name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
logo: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
@router.post("/add_provider")
|
||||
def add_provider(data: ProviderRequest):
|
||||
try:
|
||||
ProviderService.add_provider(
|
||||
name=data.name,
|
||||
api_key=data.api_key,
|
||||
base_url=data.base_url,
|
||||
logo=data.logo,
|
||||
type_=data.type
|
||||
)
|
||||
return R.success(msg='添加模型供应商成功')
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.get("/get_all_providers")
|
||||
def get_all_providers():
|
||||
try:
|
||||
res = ProviderService.get_all_providers()
|
||||
return R.success(data=res)
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.get("/get_provider_by_id/{id}")
|
||||
def get_provider_by_id(id: int):
|
||||
try:
|
||||
res = ProviderService.get_provider_by_id(id)
|
||||
return R.success(data=res)
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.get("/get_provider_by_name/{name}")
|
||||
def get_provider_by_name(name: str):
|
||||
try:
|
||||
res = ProviderService.get_provider_by_name(name)
|
||||
return R.success(data=res)
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.post("/update_provider/")
|
||||
def update_provider(data: ProviderUpdateRequest):
|
||||
try:
|
||||
if all(
|
||||
field is None
|
||||
for field in [data.name, data.api_key, data.base_url, data.logo, data.type]
|
||||
):
|
||||
return R.error(msg='请至少填写一个参数')
|
||||
|
||||
ProviderService.update_provider(
|
||||
id=data.id,
|
||||
name=data.name or '',
|
||||
api_key=data.api_key or '',
|
||||
base_url=data.base_url or '',
|
||||
logo=data.logo or '',
|
||||
type_=data.type or ''
|
||||
)
|
||||
return R.success(msg='更新模型供应商成功')
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
23
backend/app/services/model.py
Normal file
23
backend/app/services/model.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
|
||||
|
||||
class ModelService:
|
||||
@staticmethod
|
||||
def get_model_list(provider_id: int):
|
||||
provider=ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
return []
|
||||
config=ModelConfig(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
provider=provider.name,
|
||||
model_name='',
|
||||
name=provider.name,
|
||||
)
|
||||
GPT=GPTFactory().from_config(config)
|
||||
return GPT.list_models()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(ModelService.get_model_list(1))
|
||||
@@ -53,21 +53,18 @@ class NoteGenerator:
|
||||
|
||||
|
||||
def get_gpt(self) -> GPT:
|
||||
self.provider = self.provider.lower()
|
||||
if self.provider == 'openai':
|
||||
logger.info("使用OpenAI")
|
||||
return OpenaiGPT()
|
||||
elif self.provider == 'deepseek':
|
||||
elif self.provider == 'deepSeek':
|
||||
logger.info("使用DeepSeek")
|
||||
return DeepSeekGPT()
|
||||
elif self.provider == 'qwen':
|
||||
logger.info("使用Qwen")
|
||||
return QwenGPT()
|
||||
else:
|
||||
self.provider = 'openai'
|
||||
logger.warning("不支持的AI提供商,使用 OpenAI 做完GPT")
|
||||
return OpenaiGPT()
|
||||
|
||||
logger.warning("不支持的AI提供商")
|
||||
raise ValueError(f"不支持的AI提供商:{self.provider}")
|
||||
|
||||
def get_downloader(self, platform: str) -> Downloader:
|
||||
if platform == "bilibili":
|
||||
@@ -162,9 +159,9 @@ class NoteGenerator:
|
||||
# 1. 选择下载器
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt()
|
||||
logger.info(f'使用{downloader.__class__.__name__}下载器')
|
||||
logger.info(f'使用{gpt.__class__.__name__}GPT')
|
||||
logger.info(f'视频地址:{video_url}')
|
||||
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
|
||||
f'使用{gpt.__class__.__name__}GPT\n'
|
||||
f'视频地址:{video_url}')
|
||||
if screenshot:
|
||||
|
||||
video_path = downloader.download_video(video_url)
|
||||
|
||||
54
backend/app/services/provider.py
Normal file
54
backend/app/services/provider.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from app.db.provider_dao import (
|
||||
insert_provider,
|
||||
init_provider_table,
|
||||
get_all_providers,
|
||||
get_provider_by_name,
|
||||
get_provider_by_id,
|
||||
update_provider,
|
||||
delete_provider,
|
||||
)
|
||||
|
||||
class ProviderService:
|
||||
|
||||
@staticmethod
|
||||
def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
return insert_provider(name, api_key, base_url, logo, type_)
|
||||
|
||||
@staticmethod
|
||||
def get_all_providers():
|
||||
provider_list = []
|
||||
provider = get_all_providers()
|
||||
|
||||
for i in provider:
|
||||
provider_list.append({
|
||||
"id": i[0],
|
||||
"name": i[1],
|
||||
"logo": i[2],
|
||||
"type": i[3], # ✅ 加上类型
|
||||
"api_key": i[4],
|
||||
"base_url": i[5],
|
||||
})
|
||||
return provider_list
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_name(name: str):
|
||||
return get_provider_by_name(name)
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_id(id: int):
|
||||
return get_provider_by_id(id)
|
||||
|
||||
@staticmethod
|
||||
def update_provider(
|
||||
id: int,
|
||||
name: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
logo: str,
|
||||
type_: str
|
||||
):
|
||||
return update_provider(id, name, api_key, base_url, logo, type_)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider(id: int):
|
||||
return delete_provider(id)
|
||||
@@ -1,250 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
import requests
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
__version__ = "0.0.3"
|
||||
|
||||
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
|
||||
|
||||
# 申请上传
|
||||
API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
|
||||
|
||||
# 提交上传
|
||||
API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
|
||||
|
||||
# 创建任务
|
||||
API_CREATE_TASK = API_BASE_URL + "/task"
|
||||
|
||||
# 查询结果
|
||||
API_QUERY_RESULT = API_BASE_URL + "/task/result"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class BcutTranscriber(Transcriber):
|
||||
"""必剪 语音识别接口"""
|
||||
headers = {
|
||||
'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.task_id = None
|
||||
self.__etags = []
|
||||
|
||||
self.__in_boss_key: Optional[str] = None
|
||||
self.__resource_id: Optional[str] = None
|
||||
self.__upload_id: Optional[str] = None
|
||||
self.__upload_urls: List[str] = []
|
||||
self.__per_size: Optional[int] = None
|
||||
self.__clips: Optional[int] = None
|
||||
|
||||
self.__etags: List[str] = []
|
||||
self.__download_url: Optional[str] = None
|
||||
self.task_id: Optional[str] = None
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _upload(self, file_path: str) -> None:
|
||||
"""申请上传"""
|
||||
file_binary = self._load_file(file_path)
|
||||
if not file_binary:
|
||||
raise ValueError("无法读取文件数据")
|
||||
|
||||
payload = json.dumps({
|
||||
"type": 2,
|
||||
"name": "audio.mp3",
|
||||
"size": len(file_binary),
|
||||
"ResourceFileType": "mp3",
|
||||
"model_id": "8",
|
||||
})
|
||||
|
||||
resp = self.session.post(
|
||||
API_REQ_UPLOAD,
|
||||
data=payload,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
resp_data = resp["data"]
|
||||
|
||||
self.__in_boss_key = resp_data["in_boss_key"]
|
||||
self.__resource_id = resp_data["resource_id"]
|
||||
self.__upload_id = resp_data["upload_id"]
|
||||
self.__upload_urls = resp_data["upload_urls"]
|
||||
self.__per_size = resp_data["per_size"]
|
||||
self.__clips = len(resp_data["upload_urls"])
|
||||
|
||||
logger.info(
|
||||
f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
|
||||
)
|
||||
self.__upload_part(file_binary)
|
||||
self.__commit_upload()
|
||||
|
||||
def __upload_part(self, file_binary: bytes) -> None:
|
||||
"""上传音频数据"""
|
||||
for clip in range(self.__clips):
|
||||
start_range = clip * self.__per_size
|
||||
end_range = min((clip + 1) * self.__per_size, len(file_binary))
|
||||
logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
|
||||
resp = self.session.put(
|
||||
self.__upload_urls[clip],
|
||||
data=file_binary[start_range:end_range],
|
||||
headers={'Content-Type': 'application/octet-stream'}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
etag = resp.headers.get("Etag", "").strip('"')
|
||||
self.__etags.append(etag)
|
||||
logger.info(f"分片{clip}上传成功: {etag}")
|
||||
|
||||
def __commit_upload(self) -> None:
|
||||
"""提交上传数据"""
|
||||
data = json.dumps({
|
||||
"InBossKey": self.__in_boss_key,
|
||||
"ResourceId": self.__resource_id,
|
||||
"Etags": ",".join(self.__etags),
|
||||
"UploadId": self.__upload_id,
|
||||
"model_id": "8",
|
||||
})
|
||||
resp = self.session.post(
|
||||
API_COMMIT_UPLOAD,
|
||||
data=data,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.__download_url = resp["data"]["download_url"]
|
||||
logger.info(f"提交成功,下载链接: {self.__download_url}")
|
||||
|
||||
def _create_task(self) -> str:
|
||||
"""开始创建转换任务"""
|
||||
resp = self.session.post(
|
||||
API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.task_id = resp["data"]["task_id"]
|
||||
logger.info(f"任务已创建: {self.task_id}")
|
||||
return self.task_id
|
||||
|
||||
def _query_result(self) -> dict:
|
||||
"""查询转换结果"""
|
||||
resp = self.session.get(
|
||||
API_QUERY_RESULT,
|
||||
params={"model_id": 7, "task_id": self.task_id},
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return resp["data"]
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行识别过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 上传文件
|
||||
logger.info("正在上传文件...")
|
||||
self._upload(file_path)
|
||||
|
||||
# 创建任务
|
||||
logger.info("提交转录任务...")
|
||||
self._create_task()
|
||||
|
||||
# 轮询检查任务状态
|
||||
logger.info("等待转录结果...")
|
||||
task_resp = None
|
||||
max_retries = 500
|
||||
for i in range(max_retries):
|
||||
task_resp = self._query_result()
|
||||
|
||||
if task_resp["state"] == 4: # 完成状态
|
||||
break
|
||||
elif task_resp["state"] == 3: # 失败状态
|
||||
error_msg = f"B站ASR任务失败,状态码: {task_resp['state']}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 每隔一段时间打印进度
|
||||
if i % 10 == 0:
|
||||
logger.info(f"转录进行中... {i}/{max_retries}")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if not task_resp or task_resp["state"] != 4:
|
||||
error_msg = f"B站ASR任务未能完成,状态: {task_resp.get('state') if task_resp else 'Unknown'}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析结果
|
||||
logger.info("转录成功,处理结果...")
|
||||
result_json = json.loads(task_resp["result"])
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
for u in result_json.get("utterances", []):
|
||||
text = u.get("transcript", "").strip()
|
||||
# B站ASR返回的时间戳是毫秒,需要转换为秒
|
||||
start_time = float(u.get("start_time", 0)) / 1000.0
|
||||
end_time = float(u.get("end_time", 0)) / 1000.0
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language=result_json.get("language", "zh"),
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_json
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"B站ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"B站ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
@@ -1,115 +0,0 @@
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
from typing import Union, List, Dict, Optional
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class KuaishouTranscriber(Transcriber):
|
||||
"""快手语音识别实现"""
|
||||
|
||||
API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _submit(self, file_path: str) -> dict:
|
||||
"""提交识别请求"""
|
||||
try:
|
||||
file_binary = self._load_file(file_path)
|
||||
|
||||
payload = {
|
||||
"typeId": "1"
|
||||
}
|
||||
|
||||
# 使用文件名作为上传文件名
|
||||
file_name = os.path.basename(file_path)
|
||||
files = [('file', (file_name, file_binary, 'audio/mpeg'))]
|
||||
|
||||
logger.info(f"开始向快手API提交请求,文件: {file_name}")
|
||||
response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
|
||||
response.raise_for_status() # 检查HTTP错误
|
||||
|
||||
result = response.json()
|
||||
|
||||
# 检查快手API返回是否包含错误
|
||||
if "data" not in result or result.get("code", 0) != 0:
|
||||
error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"快手ASR请求网络错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"快手ASR请求处理错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行转录过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 提交请求并获取结果
|
||||
logger.info("向快手API提交识别请求...")
|
||||
result_data = self._submit(file_path)
|
||||
|
||||
logger.info("请求成功,处理结果...")
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
# 解析快手API返回的文本段
|
||||
texts = result_data.get('data', {}).get('text', [])
|
||||
for u in texts:
|
||||
text = u.get('text', '').strip()
|
||||
start_time = float(u.get('start_time', 0))
|
||||
end_time = float(u.get('end_time', 0))
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language="zh", # 快手API可能不返回语言信息,默认为中文
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_data
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"快手ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"快手ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
@@ -1,74 +1,19 @@
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
from app.transcriber.bcut import BcutTranscriber
|
||||
from app.transcriber.kuaishou import KuaishouTranscriber
|
||||
from app.utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info('初始化转录服务提供器')
|
||||
logger.info('实例化transcriber')
|
||||
# TODO:后面需要加入逻辑选择
|
||||
_transcriber = None
|
||||
|
||||
# 维护各种转录器的单例实例
|
||||
_transcribers = {
|
||||
'whisper': None,
|
||||
'bcut': None,
|
||||
'kuaishou': None
|
||||
}
|
||||
def get_transcriber(model_size="base", device="cuda"):
|
||||
global _transcriber
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
"""获取 Whisper 转录器实例"""
|
||||
if _transcribers['whisper'] is None:
|
||||
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
|
||||
if _transcriber is None:
|
||||
logger.info('不存在 transcriber ,开始实例化transcriber。')
|
||||
try:
|
||||
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info('Whisper 转录器创建成功')
|
||||
_transcriber = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info(f'实例化transcriber成功。参数:{model_size}, {device} ')
|
||||
except Exception as e:
|
||||
logger.error(f"Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['whisper']
|
||||
|
||||
def get_bcut_transcriber():
|
||||
"""获取 Bcut 转录器实例"""
|
||||
if _transcribers['bcut'] is None:
|
||||
logger.info('创建 Bcut 转录器实例')
|
||||
try:
|
||||
_transcribers['bcut'] = BcutTranscriber()
|
||||
logger.info('Bcut 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Bcut 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['bcut']
|
||||
|
||||
def get_kuaishou_transcriber():
|
||||
"""获取快手转录器实例"""
|
||||
if _transcribers['kuaishou'] is None:
|
||||
logger.info('创建快手转录器实例')
|
||||
try:
|
||||
_transcribers['kuaishou'] = KuaishouTranscriber()
|
||||
logger.info('快手转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"快手转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['kuaishou']
|
||||
|
||||
def get_transcriber(transcriber_type="whisper", model_size="base", device="cuda"):
|
||||
"""
|
||||
获取指定类型的转录器实例
|
||||
|
||||
参数:
|
||||
transcriber_type: 转录器类型,支持 "whisper", "bcut", "kuaishou"
|
||||
model_size: 模型大小,whisper 特有参数
|
||||
device: 设备类型,whisper 特有参数
|
||||
|
||||
返回:
|
||||
对应类型的转录器实例
|
||||
"""
|
||||
logger.info(f'获取转录器,类型: {transcriber_type}')
|
||||
|
||||
if transcriber_type == "whisper":
|
||||
return get_whisper_transcriber(model_size, device)
|
||||
elif transcriber_type == "bcut":
|
||||
return get_bcut_transcriber()
|
||||
elif transcriber_type == "kuaishou":
|
||||
return get_kuaishou_transcriber()
|
||||
else:
|
||||
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
|
||||
return get_whisper_transcriber(model_size, device)
|
||||
logger.error(f"实例化transcriber失败,请检查是否安装whisper。{e}")
|
||||
return _transcriber
|
||||
@@ -4,19 +4,14 @@ from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
|
||||
from events import transcription_finished
|
||||
from pathlib import Path
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
'''
|
||||
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
|
||||
'''
|
||||
logger=get_logger(__name__)
|
||||
|
||||
|
||||
class WhisperTranscriber(Transcriber):
|
||||
# TODO:修改为可配置
|
||||
@@ -36,25 +31,15 @@ class WhisperTranscriber(Transcriber):
|
||||
|
||||
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
|
||||
|
||||
model_dir = get_model_dir("whisper")
|
||||
model_path = os.path.join(model_dir, f"whisper-{model_size}")
|
||||
if not Path(model_path).exists():
|
||||
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
|
||||
repo_id = f"guillaumekln/faster-whisper-{model_size}"
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
logger.info("模型下载完成")
|
||||
|
||||
model_path = get_model_dir("whisper")
|
||||
self.model = WhisperModel(
|
||||
model_size,
|
||||
device=self.device,
|
||||
compute_type=self.compute_type,
|
||||
# compute_type="int8", # 或 "float16"
|
||||
cpu_threads=cpu_threads,
|
||||
download_root=model_dir
|
||||
download_root=model_path
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_torch_installed() -> bool:
|
||||
try:
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.db.provider_dao import init_provider_table
|
||||
from app.utils.logger import get_logger
|
||||
from app import create_app
|
||||
from app.db.video_task_dao import init_video_task_table
|
||||
@@ -36,6 +38,7 @@ async def startup_event():
|
||||
ensure_ffmpeg_or_raise()
|
||||
get_transcriber()
|
||||
init_video_task_table()
|
||||
init_provider_table()
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.getenv("BACKEND_PORT", 8000))
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user