mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-06 00:01:54 +08:00
feat: 新增模型管理和供应商配置功能
### v1.1.0 - #### Added - 新增 AI 笔记风格选择 - 新增 AI 笔记返回格式选择 - 添加 AI 自定义笔记备注 Prompt - 添加任务失败重试 - 添加全局设置页,可在设置页进行模型设置 - #### Optimize - 优化前端样式,优化用户体验 - 增加生成中间产物,可用于失败后加快生成速度 - #### Fix - 修复视频截图视频过早删除错误
This commit is contained in:
@@ -1,23 +1,109 @@
|
||||
from app.db.model_dao import insert_model, get_all_models
|
||||
from app.db.provider_dao import get_enabled_providers
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
|
||||
|
||||
class ModelService:
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(provider_id: int):
|
||||
provider=ProviderService.get_provider_by_id(provider_id)
|
||||
def _build_model_config(provider: dict) -> ModelConfig:
|
||||
return ModelConfig(
|
||||
api_key=provider["api_key"],
|
||||
base_url=provider["base_url"],
|
||||
provider=provider["name"],
|
||||
model_name='',
|
||||
name=provider["name"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(provider_id: int, verbose: bool = False):
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
return []
|
||||
config=ModelConfig(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
provider=provider.name,
|
||||
model_name='',
|
||||
name=provider.name,
|
||||
)
|
||||
GPT=GPTFactory().from_config(config)
|
||||
return GPT.list_models()
|
||||
|
||||
try:
|
||||
config = ModelService._build_model_config(provider)
|
||||
gpt = GPTFactory().from_config(config)
|
||||
models = gpt.list_models()
|
||||
if verbose:
|
||||
print(f"[{provider['name']}] 模型列表: {models}")
|
||||
return models
|
||||
except Exception as e:
|
||||
print(f"[{provider['name']}] 获取模型失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_all_models(verbose: bool = False):
|
||||
try:
|
||||
raw_models = get_all_models()
|
||||
if verbose:
|
||||
print(f"所有模型列表: {raw_models}")
|
||||
return ModelService._format_models(raw_models)
|
||||
except Exception as e:
|
||||
print(f"获取所有模型失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_models(raw_models: list) -> list:
|
||||
"""
|
||||
格式化模型列表
|
||||
"""
|
||||
formatted = []
|
||||
for model in raw_models:
|
||||
formatted.append({
|
||||
"id": model.get("id"),
|
||||
"provider_id": model.get("provider_id"),
|
||||
"model_name": model.get("model_name"),
|
||||
"created_at": model.get("created_at", None), # 如果有created_at字段
|
||||
})
|
||||
return formatted
|
||||
@staticmethod
|
||||
def get_all_models_by_id(provider_id: str, verbose: bool = False):
|
||||
try:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
|
||||
models = ModelService.get_model_list(provider["id"], verbose=verbose)
|
||||
|
||||
model_list={
|
||||
|
||||
"models": models
|
||||
}
|
||||
|
||||
return model_list
|
||||
except Exception as e:
|
||||
print(f"[{provider_id}] 获取模型失败: {e}")
|
||||
return []
|
||||
@staticmethod
|
||||
def connect_test(api_key: str, base_url: str) -> bool:
|
||||
try:
|
||||
return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
|
||||
except Exception as e:
|
||||
print(f"连接测试失败:{e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def add_new_model(provider_id: int, model_name: str) -> bool:
|
||||
try:
|
||||
# 先查供应商是否存在
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
print(f"供应商ID {provider_id} 不存在,无法添加模型")
|
||||
return False
|
||||
|
||||
# 插入模型
|
||||
insert_model(provider_id=provider_id, model_name=model_name)
|
||||
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"添加模型失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(ModelService.get_model_list(1))
|
||||
# 单个 Provider 测试
|
||||
print(ModelService.get_model_list(1, verbose=True))
|
||||
|
||||
# 所有 Provider 模型测试
|
||||
# print(ModelService.get_all_models(verbose=True))
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
|
||||
from pydantic import HttpUrl
|
||||
|
||||
@@ -10,13 +14,17 @@ from app.downloaders.douyin_downloader import DouyinDownloader
|
||||
from app.downloaders.youtube_downloader import YoutubeDownloader
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.deepseek_gpt import DeepSeekGPT
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.openai_gpt import OpenaiGPT
|
||||
from app.gpt.qwen_gpt import QwenGPT
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.models.notes_model import NoteResult
|
||||
from app.models.notes_model import AudioDownloadResult
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.transcriber_model import TranscriptResult
|
||||
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
@@ -29,6 +37,8 @@ from app.utils.video_helper import generate_screenshot
|
||||
# from app.services.gpt import summarize_text
|
||||
from dotenv import load_dotenv
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
load_dotenv()
|
||||
BACKEND_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
|
||||
@@ -37,7 +47,7 @@ output_dir = os.getenv('OUT_DIR')
|
||||
image_base_url = os.getenv('IMAGE_BASE_URL')
|
||||
logger.info("starting up")
|
||||
|
||||
|
||||
NOTE_OUTPUT_DIR = "note_results"
|
||||
|
||||
class NoteGenerator:
|
||||
def __init__(self):
|
||||
@@ -45,26 +55,39 @@ class NoteGenerator:
|
||||
self.device: Union[str, None] = None
|
||||
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper')
|
||||
self.transcriber = self.get_transcriber()
|
||||
# TODO 需要更换为可调节
|
||||
|
||||
self.provider = os.getenv('MODEl_PROVIDER','openai')
|
||||
self.video_path = None
|
||||
logger.info("初始化NoteGenerator")
|
||||
|
||||
import logging
|
||||
|
||||
def get_gpt(self) -> GPT:
|
||||
if self.provider == 'openai':
|
||||
logger.info("使用OpenAI")
|
||||
return OpenaiGPT()
|
||||
elif self.provider == 'deepSeek':
|
||||
logger.info("使用DeepSeek")
|
||||
return DeepSeekGPT()
|
||||
elif self.provider == 'qwen':
|
||||
logger.info("使用Qwen")
|
||||
return QwenGPT()
|
||||
else:
|
||||
logger.warning("不支持的AI提供商")
|
||||
raise ValueError(f"不支持的AI提供商:{self.provider}")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@staticmethod
|
||||
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
|
||||
os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True)
|
||||
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
|
||||
content = {"status": status.value if isinstance(status, TaskStatus) else status}
|
||||
if message:
|
||||
content["message"] = message
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
|
||||
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
|
||||
|
||||
gpt = GPTFactory().from_config(
|
||||
ModelConfig(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url'),
|
||||
model_name=model_name,
|
||||
provider=provider.get('type'),
|
||||
name=provider.get('name')
|
||||
)
|
||||
)
|
||||
return gpt
|
||||
|
||||
def get_downloader(self, platform: str) -> Downloader:
|
||||
if platform == "bilibili":
|
||||
@@ -98,7 +121,7 @@ class NoteGenerator:
|
||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||
|
||||
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
|
||||
output_dir: str) -> str:
|
||||
output_dir: str,_format:list) -> str:
|
||||
"""
|
||||
扫描 markdown 中的 *Screenshot-xx:xx,生成截图并插入 markdown 图片
|
||||
:param markdown:
|
||||
@@ -145,62 +168,143 @@ class NoteGenerator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
|
||||
video_url: Union[str, HttpUrl],
|
||||
platform: str,
|
||||
quality: DownloadQuality = DownloadQuality.medium,
|
||||
task_id: Union[str, None] = None,
|
||||
model_name: str = None,
|
||||
provider_id: str = None,
|
||||
link: bool = False,
|
||||
screenshot: bool = False,
|
||||
_format: list = None,
|
||||
style: str = None,
|
||||
extras: str = None,
|
||||
path: Union[str, None] = None
|
||||
|
||||
) -> NoteResult:
|
||||
logger.info(f"开始解析并生成笔记")
|
||||
# 1. 选择下载器
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt()
|
||||
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
|
||||
f'使用{gpt.__class__.__name__}GPT\n'
|
||||
f'视频地址:{video_url}')
|
||||
if screenshot:
|
||||
try:
|
||||
logger.info(f"🎯 开始解析并生成笔记,task_id={task_id}")
|
||||
self.update_task_status(task_id, TaskStatus.PARSING)
|
||||
_path=''
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
|
||||
|
||||
video_path = downloader.download_video(video_url)
|
||||
self.video_path = video_path
|
||||
print(video_path)
|
||||
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
|
||||
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
|
||||
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
|
||||
|
||||
# 2. 下载音频
|
||||
audio: AudioDownloadResult = downloader.download(
|
||||
video_url=video_url,
|
||||
quality=quality,
|
||||
output_dir=path,
|
||||
need_video=screenshot
|
||||
# -------- 1. 下载音频 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
|
||||
if os.path.exists(audio_cache_path):
|
||||
logger.info(f"检测到已有音频缓存,直接读取,task_id={task_id}")
|
||||
with open(audio_cache_path, "r", encoding="utf-8") as f:
|
||||
audio_data = json.load(f)
|
||||
audio = AudioDownloadResult(**audio_data)
|
||||
else:
|
||||
if 'screenshot' in _format:
|
||||
video_path = downloader.download_video(video_url)
|
||||
self.video_path = video_path
|
||||
logger.info(f"成功下载视频文件: {video_path}")
|
||||
screenshot= 'screenshot' in _format
|
||||
audio: AudioDownloadResult = downloader.download(
|
||||
video_url=video_url,
|
||||
quality=quality,
|
||||
output_dir=path,
|
||||
need_video=screenshot
|
||||
)
|
||||
_path=audio.raw_info.get('path')
|
||||
with open(audio_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(audio.__dict__, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"音频下载并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 下载音频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 2. 转写文字 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
|
||||
if os.path.exists(transcript_cache_path):
|
||||
logger.info(f"检测到已有转写缓存,直接读取,task_id={task_id}")
|
||||
with open(transcript_cache_path, "r", encoding="utf-8") as f:
|
||||
transcript_data = json.load(f)
|
||||
transcript = TranscriptResult(
|
||||
language=transcript_data["language"],
|
||||
full_text=transcript_data["full_text"],
|
||||
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
|
||||
)
|
||||
else:
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"文字转写并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 转写文字失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 3. 总结内容 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
|
||||
if os.path.exists(markdown_cache_path):
|
||||
logger.info(f"检测到已有总结缓存,直接读取,task_id={task_id}")
|
||||
with open(markdown_cache_path, "r", encoding="utf-8") as f:
|
||||
markdown = f.read()
|
||||
else:
|
||||
source = GPTSource(
|
||||
title=audio.title,
|
||||
segment=transcript.segments,
|
||||
tags=audio.raw_info.get('tags'),
|
||||
screenshot=screenshot,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras
|
||||
)
|
||||
|
||||
markdown: str = gpt.summarize(source)
|
||||
with open(markdown_cache_path, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
logger.info(f"GPT总结并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 总结内容失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 4. 插入截图 --------
|
||||
if _format and 'screenshot' in _format:
|
||||
try:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir,_format)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
if _format and 'link' in _format:
|
||||
try:
|
||||
markdown = replace_content_markers(markdown, video_id=audio.video_id,platform=platform)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
# 注意:截图失败不终止整体流程
|
||||
|
||||
# -------- 5. 保存数据库记录 --------
|
||||
self.update_task_status(task_id, TaskStatus.SAVING)
|
||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
||||
|
||||
# -------- 6. 完成 --------
|
||||
self.update_task_status(task_id, TaskStatus.SUCCESS)
|
||||
logger.info(f"✅ 笔记生成成功,task_id={task_id}")
|
||||
transcription_finished.send({
|
||||
"file_path": audio.file_path,
|
||||
})
|
||||
return NoteResult(
|
||||
markdown=markdown,
|
||||
transcript=transcript,
|
||||
audio_meta=audio
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
|
||||
raise f'❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}'
|
||||
|
||||
)
|
||||
logger.info(f"下载音频成功,文件路径:{audio.file_path}")
|
||||
# 3. Whisper 转写
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
logger.info(f"Whisper 转写成功,转写结果:{transcript.full_text}")
|
||||
# 4. GPT 总结
|
||||
source = GPTSource(
|
||||
title=audio.title,
|
||||
segment=transcript.segments,
|
||||
tags=audio.raw_info.get('tags'),
|
||||
screenshot=screenshot,
|
||||
link=link
|
||||
)
|
||||
logger.info(f"GPT 总结完成,总结结果:{source}")
|
||||
markdown: str = gpt.summarize(source)
|
||||
print("markdown结果", markdown)
|
||||
|
||||
markdown = replace_content_markers(markdown=markdown, video_id=audio.video_id, platform=platform)
|
||||
if self.video_path:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir)
|
||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
||||
# 5. 返回结构体
|
||||
return NoteResult(
|
||||
markdown=markdown,
|
||||
transcript=transcript,
|
||||
audio_meta=audio
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from kombu import uuid
|
||||
|
||||
from app.db.provider_dao import (
|
||||
insert_provider,
|
||||
init_provider_table,
|
||||
@@ -5,50 +7,65 @@ from app.db.provider_dao import (
|
||||
get_provider_by_name,
|
||||
get_provider_by_id,
|
||||
update_provider,
|
||||
delete_provider,
|
||||
delete_provider, get_enabled_providers,
|
||||
)
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.models.model_config import ModelConfig
|
||||
|
||||
|
||||
class ProviderService:
|
||||
@staticmethod
|
||||
def serialize_provider(row: tuple) -> dict:
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"logo": row[2],
|
||||
"type": row[3],
|
||||
"api_key": row[4],
|
||||
"base_url": row[5],
|
||||
"enabled": row[6],
|
||||
"created_at": row[7],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
return insert_provider(name, api_key, base_url, logo, type_)
|
||||
def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
|
||||
try:
|
||||
id = uuid().lower()
|
||||
logo='custom'
|
||||
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
|
||||
except Exception as e:
|
||||
print('创建模式失败',e)
|
||||
|
||||
@staticmethod
|
||||
def get_all_providers():
|
||||
provider_list = []
|
||||
provider = get_all_providers()
|
||||
|
||||
for i in provider:
|
||||
provider_list.append({
|
||||
"id": i[0],
|
||||
"name": i[1],
|
||||
"logo": i[2],
|
||||
"type": i[3], # ✅ 加上类型
|
||||
"api_key": i[4],
|
||||
"base_url": i[5],
|
||||
})
|
||||
return provider_list
|
||||
rows = get_all_providers()
|
||||
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_name(name: str):
|
||||
return get_provider_by_name(name)
|
||||
row = get_provider_by_name(name)
|
||||
return ProviderService.serialize_provider(row)
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_id(id: int):
|
||||
return get_provider_by_id(id)
|
||||
def get_provider_by_id(id: str): # 已改为 str 类型
|
||||
row = get_provider_by_id(id)
|
||||
return ProviderService.serialize_provider(row)
|
||||
|
||||
# all_models.extend(provider['models'])
|
||||
|
||||
@staticmethod
|
||||
def update_provider(
|
||||
id: int,
|
||||
name: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
logo: str,
|
||||
type_: str
|
||||
):
|
||||
return update_provider(id, name, api_key, base_url, logo, type_)
|
||||
def update_provider(id: str, data: dict):
|
||||
try:
|
||||
# 过滤掉空值
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
|
||||
print('更新模型供应商',filtered_data)
|
||||
return update_provider(id, **filtered_data)
|
||||
|
||||
except Exception as e:
|
||||
print('更新模型供应商失败:',e)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider(id: int):
|
||||
return delete_provider(id)
|
||||
def delete_provider(id: str):
|
||||
return delete_provider(id)
|
||||
|
||||
Reference in New Issue
Block a user