feat: 新增模型管理和供应商配置功能

### v1.1.0
- #### Added
  - 新增 AI 笔记风格选择
  - 新增 AI 笔记返回格式选择
  - 添加 AI 自定义笔记备注 Prompt
  - 添加任务失败重试
  - 添加全局设置页,可在设置页进行模型设置

- #### Optimize
  - 优化前端样式,优化用户体验
  - 增加生成中间产物,可用于失败后加快生成速度
- #### Fix
  - 修复视频截图视频过早删除错误
This commit is contained in:
思诺特
2025-04-26 23:40:17 +08:00
parent 1323cfd1ec
commit 171dea5e0d
51 changed files with 2511 additions and 414 deletions

View File

@@ -1,23 +1,109 @@
from app.db.model_dao import insert_model, get_all_models
from app.db.provider_dao import get_enabled_providers
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
class ModelService:
@staticmethod
def get_model_list(provider_id: int):
provider=ProviderService.get_provider_by_id(provider_id)
def _build_model_config(provider: dict) -> ModelConfig:
return ModelConfig(
api_key=provider["api_key"],
base_url=provider["base_url"],
provider=provider["name"],
model_name='',
name=provider["name"],
)
@staticmethod
def get_model_list(provider_id: int, verbose: bool = False):
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
return []
config=ModelConfig(
api_key=provider.api_key,
base_url=provider.base_url,
provider=provider.name,
model_name='',
name=provider.name,
)
GPT=GPTFactory().from_config(config)
return GPT.list_models()
try:
config = ModelService._build_model_config(provider)
gpt = GPTFactory().from_config(config)
models = gpt.list_models()
if verbose:
print(f"[{provider['name']}] 模型列表: {models}")
return models
except Exception as e:
print(f"[{provider['name']}] 获取模型失败: {e}")
return []
@staticmethod
def get_all_models(verbose: bool = False):
try:
raw_models = get_all_models()
if verbose:
print(f"所有模型列表: {raw_models}")
return ModelService._format_models(raw_models)
except Exception as e:
print(f"获取所有模型失败: {e}")
return []
@staticmethod
def _format_models(raw_models: list) -> list:
"""
格式化模型列表
"""
formatted = []
for model in raw_models:
formatted.append({
"id": model.get("id"),
"provider_id": model.get("provider_id"),
"model_name": model.get("model_name"),
"created_at": model.get("created_at", None), # 如果有created_at字段
})
return formatted
@staticmethod
def get_all_models_by_id(provider_id: str, verbose: bool = False):
try:
provider = ProviderService.get_provider_by_id(provider_id)
models = ModelService.get_model_list(provider["id"], verbose=verbose)
model_list={
"models": models
}
return model_list
except Exception as e:
print(f"[{provider_id}] 获取模型失败: {e}")
return []
@staticmethod
def connect_test(api_key: str, base_url: str) -> bool:
try:
return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
except Exception as e:
print(f"连接测试失败:{e}")
return False
@staticmethod
def add_new_model(provider_id: int, model_name: str) -> bool:
try:
# 先查供应商是否存在
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
print(f"供应商ID {provider_id} 不存在,无法添加模型")
return False
# 插入模型
insert_model(provider_id=provider_id, model_name=model_name)
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
return True
except Exception as e:
print(f"添加模型失败: {e}")
return False
if __name__ == '__main__':
print(ModelService.get_model_list(1))
# 单个 Provider 测试
print(ModelService.get_model_list(1, verbose=True))
# 所有 Provider 模型测试
# print(ModelService.get_all_models(verbose=True))

View File

@@ -1,5 +1,9 @@
import json
from dataclasses import asdict
from app.enmus.task_status_enums import TaskStatus
import os
from typing import Union
from typing import Union, Optional
from pydantic import HttpUrl
@@ -10,13 +14,17 @@ from app.downloaders.douyin_downloader import DouyinDownloader
from app.downloaders.youtube_downloader import YoutubeDownloader
from app.gpt.base import GPT
from app.gpt.deepseek_gpt import DeepSeekGPT
from app.gpt.gpt_factory import GPTFactory
from app.gpt.openai_gpt import OpenaiGPT
from app.gpt.qwen_gpt import QwenGPT
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.notes_model import NoteResult
from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
from app.transcriber.whisper import WhisperTranscriber
@@ -29,6 +37,8 @@ from app.utils.video_helper import generate_screenshot
# from app.services.gpt import summarize_text
from dotenv import load_dotenv
from app.utils.logger import get_logger
from events import transcription_finished
logger = get_logger(__name__)
load_dotenv()
BACKEND_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
@@ -37,7 +47,7 @@ output_dir = os.getenv('OUT_DIR')
image_base_url = os.getenv('IMAGE_BASE_URL')
logger.info("starting up")
NOTE_OUTPUT_DIR = "note_results"
class NoteGenerator:
def __init__(self):
@@ -45,26 +55,39 @@ class NoteGenerator:
self.device: Union[str, None] = None
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper')
self.transcriber = self.get_transcriber()
# TODO 需要更换为可调节
self.provider = os.getenv('MODEl_PROVIDER','openai')
self.video_path = None
logger.info("初始化NoteGenerator")
import logging
def get_gpt(self) -> GPT:
if self.provider == 'openai':
logger.info("使用OpenAI")
return OpenaiGPT()
elif self.provider == 'deepSeek':
logger.info("使用DeepSeek")
return DeepSeekGPT()
elif self.provider == 'qwen':
logger.info("使用Qwen")
return QwenGPT()
else:
logger.warning("不支持的AI提供商")
raise ValueError(f"不支持的AI提供商{self.provider}")
logger = logging.getLogger(__name__)
@staticmethod
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True)
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
content = {"status": status.value if isinstance(status, TaskStatus) else status}
if message:
content["message"] = message
with open(path, "w", encoding="utf-8") as f:
json.dump(content, f, ensure_ascii=False, indent=2)
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
gpt = GPTFactory().from_config(
ModelConfig(
api_key=provider.get('api_key'),
base_url=provider.get('base_url'),
model_name=model_name,
provider=provider.get('type'),
name=provider.get('name')
)
)
return gpt
def get_downloader(self, platform: str) -> Downloader:
if platform == "bilibili":
@@ -98,7 +121,7 @@ class NoteGenerator:
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
output_dir: str) -> str:
output_dir: str,_format:list) -> str:
"""
扫描 markdown 中的 *Screenshot-xx:xx生成截图并插入 markdown 图片
:param markdown:
@@ -145,62 +168,143 @@ class NoteGenerator:
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Union[str, None] = None,
model_name: str = None,
provider_id: str = None,
link: bool = False,
screenshot: bool = False,
_format: list = None,
style: str = None,
extras: str = None,
path: Union[str, None] = None
) -> NoteResult:
logger.info(f"开始解析并生成笔记")
# 1. 选择下载器
downloader = self.get_downloader(platform)
gpt = self.get_gpt()
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
f'使用{gpt.__class__.__name__}GPT\n'
f'视频地址:{video_url}')
if screenshot:
try:
logger.info(f"🎯 开始解析并生成笔记task_id={task_id}")
self.update_task_status(task_id, TaskStatus.PARSING)
_path=''
downloader = self.get_downloader(platform)
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
video_path = downloader.download_video(video_url)
self.video_path = video_path
print(video_path)
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
# 2. 下载音频
audio: AudioDownloadResult = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video=screenshot
# -------- 1. 下载音频 --------
try:
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
if os.path.exists(audio_cache_path):
logger.info(f"检测到已有音频缓存直接读取task_id={task_id}")
with open(audio_cache_path, "r", encoding="utf-8") as f:
audio_data = json.load(f)
audio = AudioDownloadResult(**audio_data)
else:
if 'screenshot' in _format:
video_path = downloader.download_video(video_url)
self.video_path = video_path
logger.info(f"成功下载视频文件: {video_path}")
screenshot= 'screenshot' in _format
audio: AudioDownloadResult = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video=screenshot
)
_path=audio.raw_info.get('path')
with open(audio_cache_path, "w", encoding="utf-8") as f:
json.dump(audio.__dict__, f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise e
# -------- 2. 转写文字 --------
try:
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
if os.path.exists(transcript_cache_path):
logger.info(f"检测到已有转写缓存直接读取task_id={task_id}")
with open(transcript_cache_path, "r", encoding="utf-8") as f:
transcript_data = json.load(f)
transcript = TranscriptResult(
language=transcript_data["language"],
full_text=transcript_data["full_text"],
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
else:
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise e
# -------- 3. 总结内容 --------
try:
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
if os.path.exists(markdown_cache_path):
logger.info(f"检测到已有总结缓存直接读取task_id={task_id}")
with open(markdown_cache_path, "r", encoding="utf-8") as f:
markdown = f.read()
else:
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
link=link,
_format=_format,
style=style,
extras=extras
)
markdown: str = gpt.summarize(source)
with open(markdown_cache_path, "w", encoding="utf-8") as f:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise e
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
try:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir,_format)
except Exception as e:
logger.warning(f"⚠️ 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id,platform=platform)
except Exception as e:
logger.warning(f"⚠️ 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
self.update_task_status(task_id, TaskStatus.SAVING)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f"✅ 笔记生成成功task_id={task_id}")
transcription_finished.send({
"file_path": audio.file_path,
})
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)
except Exception as e:
logger.error(f"❌ 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
raise f'❌ 笔记生成流程异常终止task_id={task_id},错误信息:{e}'
)
logger.info(f"下载音频成功,文件路径:{audio.file_path}")
# 3. Whisper 转写
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
logger.info(f"Whisper 转写成功,转写结果:{transcript.full_text}")
# 4. GPT 总结
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
link=link
)
logger.info(f"GPT 总结完成,总结结果:{source}")
markdown: str = gpt.summarize(source)
print("markdown结果", markdown)
markdown = replace_content_markers(markdown=markdown, video_id=audio.video_id, platform=platform)
if self.video_path:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# 5. 返回结构体
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)

View File

@@ -1,3 +1,5 @@
from kombu import uuid
from app.db.provider_dao import (
insert_provider,
init_provider_table,
@@ -5,50 +7,65 @@ from app.db.provider_dao import (
get_provider_by_name,
get_provider_by_id,
update_provider,
delete_provider,
delete_provider, get_enabled_providers,
)
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
class ProviderService:
@staticmethod
def serialize_provider(row: tuple) -> dict:
if not row:
return None
return {
"id": row[0],
"name": row[1],
"logo": row[2],
"type": row[3],
"api_key": row[4],
"base_url": row[5],
"enabled": row[6],
"created_at": row[7],
}
@staticmethod
def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
return insert_provider(name, api_key, base_url, logo, type_)
def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
try:
id = uuid().lower()
logo='custom'
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
except Exception as e:
print('创建模式失败',e)
@staticmethod
def get_all_providers():
provider_list = []
provider = get_all_providers()
for i in provider:
provider_list.append({
"id": i[0],
"name": i[1],
"logo": i[2],
"type": i[3], # ✅ 加上类型
"api_key": i[4],
"base_url": i[5],
})
return provider_list
rows = get_all_providers()
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
@staticmethod
def get_provider_by_name(name: str):
return get_provider_by_name(name)
row = get_provider_by_name(name)
return ProviderService.serialize_provider(row)
@staticmethod
def get_provider_by_id(id: int):
return get_provider_by_id(id)
def get_provider_by_id(id: str): # 已改为 str 类型
row = get_provider_by_id(id)
return ProviderService.serialize_provider(row)
# all_models.extend(provider['models'])
@staticmethod
def update_provider(
id: int,
name: str,
api_key: str,
base_url: str,
logo: str,
type_: str
):
return update_provider(id, name, api_key, base_url, logo, type_)
def update_provider(id: str, data: dict):
try:
# 过滤掉空值
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
print('更新模型供应商',filtered_data)
return update_provider(id, **filtered_data)
except Exception as e:
print('更新模型供应商失败:',e)
@staticmethod
def delete_provider(id: int):
return delete_provider(id)
def delete_provider(id: str):
return delete_provider(id)