mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
- 新增自定义异常类 BizException、NoteError 和 ProviderError - 优化了模型管理相关的逻辑,包括加载、删除和测试连接等功能 - 改进了 Douyin 下载器的错误处理 - 调整了任务重试逻辑和笔记生成的异常处理 - 更新了相关组件和页面以适应新的异常处理机制
247 lines
8.3 KiB
Python
247 lines
8.3 KiB
Python
# app/routers/note.py
|
||
import json
|
||
import os
|
||
import uuid
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
from urllib.parse import urlparse
|
||
|
||
from fastapi import APIRouter, HTTPException, BackgroundTasks, UploadFile, File
|
||
from pydantic import BaseModel, validator, field_validator
|
||
from dataclasses import asdict
|
||
|
||
from app.db.video_task_dao import get_task_by_video
|
||
from app.enmus.exception import NoteErrorEnum
|
||
from app.enmus.note_enums import DownloadQuality
|
||
from app.exceptions.note import NoteError
|
||
from app.services.note import NoteGenerator, logger
|
||
from app.utils.response import ResponseWrapper as R
|
||
from app.utils.url_parser import extract_video_id
|
||
from app.validators.video_url_validator import is_supported_video_url
|
||
from fastapi import APIRouter, Request, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
import httpx
|
||
from app.enmus.task_status_enums import TaskStatus
|
||
|
||
# from app.services.downloader import download_raw_audio
|
||
# from app.services.whisperer import transcribe_audio
|
||
|
||
# Router carrying the note endpoints declared in this module
# (/delete_task, /upload, /generate_note, /task_status/{task_id}, /image_proxy).
# Presumably included into the FastAPI app elsewhere -- prefix not visible here.
router = APIRouter()
|
||
|
||
|
||
class RecordRequest(BaseModel):
    """Request body identifying a stored note by its video id and platform.

    Used by ``/delete_task``, which forwards both fields to
    ``NoteGenerator().delete_note``.
    """

    # Video identifier on the source platform.
    video_id: str
    # Platform name the video belongs to (presumably e.g. "bilibili" -- TODO confirm accepted values).
    platform: str
|
||
|
||
|
||
class VideoRequest(BaseModel):
    """Request body for ``/generate_note``.

    The fields are forwarded to ``run_note_task`` (and from there to
    ``NoteGenerator().generate``) in the order the endpoint passes them.
    """

    # Remote video URL, or a non-http(s) reference such as a path returned by /upload.
    video_url: str
    platform: str
    quality: DownloadQuality
    screenshot: Optional[bool] = False
    link: Optional[bool] = False
    model_name: str
    provider_id: str
    # Set only when retrying: the endpoint reuses this task id instead of allocating a new one.
    task_id: Optional[str] = None
    # Pydantic copies mutable field defaults per instance, so `[]` is not shared between requests.
    format: Optional[list] = []
    # Optional[...] so an explicit JSON null is accepted, consistent with the None default.
    style: Optional[str] = None
    extras: Optional[str] = None
    video_understanding: Optional[bool] = False
    video_interval: Optional[int] = 0
    grid_size: Optional[list] = []

    @field_validator("video_url")
    @classmethod
    def validate_supported_url(cls, v):
        """Reject http(s) URLs whose platform is not supported.

        Values without an http/https scheme (e.g. local upload paths) are
        returned unchanged.

        Raises:
            NoteError: PLATFORM_NOT_SUPPORTED for an unsupported web URL. This
                is not a ValueError, so pydantic does not wrap it in a
                ValidationError; presumably an app-level exception handler
                turns it into the API error response (TODO confirm).
        """
        url = str(v)
        parsed = urlparse(url)
        # Only web links are checked against the platform whitelist.
        if parsed.scheme in ("http", "https") and not is_supported_video_url(url):
            raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
                            message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
        return v
|
||
|
||
|
||
# Directory holding generated notes (<task_id>.json) and task status files
# (<task_id>.status.json); overridable through the NOTE_OUTPUT_DIR env var.
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
# Local directory /upload writes into. The URL /upload returns assumes this
# directory is served as static files under /uploads (mounted elsewhere -- TODO confirm).
UPLOAD_DIR = "uploads"
|
||
|
||
|
||
def save_note_to_file(task_id: str, note):
    """Persist a generated note as ``<NOTE_OUTPUT_DIR>/<task_id>.json``.

    The note is converted with ``dataclasses.asdict`` and written as UTF-8,
    pretty-printed JSON. The write is atomic: data goes to a temporary file in
    the same directory, which is then renamed over the target. As a result,
    ``get_task_status`` polling concurrently never reads a half-written file,
    and a failure never leaves a previous result truncated.

    Args:
        task_id: Task identifier; used as the file name stem.
        note: A dataclass instance (must be accepted by ``asdict``).

    Raises:
        TypeError: if ``note`` is not a dataclass instance.
        OSError: if the directory or file cannot be written.
    """
    # Serialize first so a bad `note` fails before anything touches the disk.
    payload = asdict(note)

    os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True)
    target_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
    # Same directory as the target so os.replace is a same-filesystem rename.
    tmp_path = f"{target_path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, target_path)
    except Exception:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
||
|
||
|
||
def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality,
                  link: bool = False, screenshot: bool = False,
                  model_name: Optional[str] = None, provider_id: Optional[str] = None,
                  _format: Optional[list] = None, style: Optional[str] = None,
                  extras: Optional[str] = None, video_understanding: bool = False,
                  video_interval: int = 0, grid_size: Optional[list] = None):
    """Generate a note for one video and save it under ``task_id``.

    Scheduled by ``/generate_note`` via ``BackgroundTasks``, so it runs after
    the HTTP response has been sent. Every argument is forwarded to
    ``NoteGenerator().generate``. A note with content is written with
    ``save_note_to_file``; an empty/failed result is logged and skipped.

    Args:
        grid_size: Forwarded as-is; ``None`` (the default) is sent as ``[]``.
            The default was previously a shared mutable ``[]``.

    Raises:
        HTTPException: if ``model_name`` or ``provider_id`` is missing. Because
            this runs as a background task, the exception only aborts the task;
            it never reaches the client as a 400.
    """
    if not model_name or not provider_id:
        logger.error(f"Task {task_id}: model_name/provider_id missing, aborting note generation")
        raise HTTPException(status_code=400, detail="请选择模型和提供者")

    note = NoteGenerator().generate(
        video_url=video_url,
        platform=platform,
        quality=quality,
        task_id=task_id,
        model_name=model_name,
        provider_id=provider_id,
        link=link,
        _format=_format,
        style=style,
        extras=extras,
        screenshot=screenshot,
        video_understanding=video_understanding,
        video_interval=video_interval,
        # Fresh list per call instead of a default shared across invocations.
        grid_size=grid_size if grid_size is not None else [],
    )

    if not note or not note.markdown:
        logger.warning(f"任务 {task_id} 执行失败,跳过保存")
        return

    # Logged only once a usable note exists, so the log no longer reports
    # failed runs as generated.
    logger.info(f"Note generated: {task_id}")
    save_note_to_file(task_id, note)
|
||
|
||
|
||
|
||
@router.post('/delete_task')
def delete_task(data: RecordRequest):
    """Delete the stored note for ``data.video_id`` on ``data.platform``.

    Returns:
        ``R.success(msg='删除成功')`` when ``NoteGenerator().delete_note``
        completes. On any exception, ``R.error`` with the exception's text
        as ``msg``; the failure is also logged with its traceback.
    """
    try:
        NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform)
    except Exception as e:
        # Endpoint boundary: report the failure in the response body rather
        # than as a 500. Pass the text, not the exception object -- presumably
        # `msg` is serialized into the JSON body (TODO confirm), and exception
        # objects are not JSON-serializable.
        logger.exception(f"Failed to delete note: video_id={data.video_id}, platform={data.platform}")
        return R.error(msg=str(e))
    return R.success(msg='删除成功')
|
||
|
||
|
||
@router.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Store an uploaded file under ``UPLOAD_DIR`` and return its public URL.

    Only the final component of the client-supplied file name is used, so a
    name such as ``../../app/main.py`` cannot escape ``UPLOAD_DIR``. An
    existing file with the same name is overwritten.

    Returns:
        ``R.success({"url": "/uploads/<name>"})``.

    Raises:
        HTTPException(400): the upload carries no usable file name.
    """
    # file.filename comes straight from the multipart request (and may be
    # None); strip directory parts, including Windows-style "\" separators.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="文件名无效")

    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_location = os.path.join(UPLOAD_DIR, filename)

    # Copy in 1 MiB chunks so large videos are not held in memory all at once.
    chunk_size = 1024 * 1024
    with open(file_location, "wb") as f:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

    # Assumes UPLOAD_DIR is mounted as static files under /uploads.
    return R.success({"url": f"/uploads/{filename}"})
|
||
|
||
|
||
@router.post("/generate_note")
def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
    """Queue note generation for ``data.video_url`` as a background task.

    If ``data.task_id`` is given, the request is a retry: that task id is
    reused and its status is reset to PENDING. Otherwise a new UUID4 task id
    is allocated. The client polls ``/task_status/{task_id}`` for progress.

    Returns:
        ``R.success({"task_id": task_id})`` once the task is queued.

    Raises:
        HTTPException(400): model/provider missing, or ``task_id`` is not a
            plain file-name component.
        HTTPException(500): any other failure while queueing the task.
    """
    # Check here as well: run_note_task repeats this check, but it runs after
    # the response is sent, so the client would otherwise get a task id that
    # can never succeed.
    if not data.model_name or not data.provider_id:
        raise HTTPException(status_code=400, detail="请选择模型和提供者")

    # The retry task_id is client-supplied and becomes part of on-disk file
    # names under NOTE_OUTPUT_DIR, so reject anything with path separators.
    if data.task_id and os.path.basename(data.task_id) != data.task_id:
        raise HTTPException(status_code=400, detail="无效的 task_id")

    try:
        if data.task_id:
            # Retry: reuse the existing task and reset its status.
            task_id = data.task_id
            NoteGenerator.update_task_status(task_id, TaskStatus.PENDING)
            logger.info(f"重试模式,复用已有 task_id={task_id}")
        else:
            # New task.
            task_id = str(uuid.uuid4())

        # Positional order must match run_note_task's signature.
        background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality, data.link,
                                  data.screenshot, data.model_name, data.provider_id, data.format, data.style,
                                  data.extras, data.video_understanding, data.video_interval, data.grid_size)
        return R.success({"task_id": task_id})
    except Exception as e:
        logger.exception(f"Failed to queue note task for {data.video_url}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
def _read_json_file(path: str) -> Optional[dict]:
    """Load JSON from ``path``; return None if it is missing or not valid JSON (yet).

    Status/result files are written by background work while this endpoint is
    being polled. A file can therefore vanish between checks or be observed
    mid-write; both cases mean "not available yet", not a server error.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        return None


@router.get("/task_status/{task_id}")
def get_task_status(task_id: str):
    """Report the progress of a note task, including the note once it succeeded.

    Resolution order:
      1. ``<task_id>.status.json`` if readable:
         * SUCCESS -> the note from ``<task_id>.json``; if that file is not
           readable yet, report PENDING so the client keeps polling.
         * FAILED  -> ``R.error`` with the stored message (or "任务失败").
         * other   -> the stored status and message.
      2. No readable status file but a readable result file -> SUCCESS with the note.
      3. Nothing readable -> PENDING ("任务排队中").
    """
    status_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
    result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")

    # The status file takes precedence.
    status_content = _read_json_file(status_path)
    if status_content is not None:
        status = status_content.get("status")
        message = status_content.get("message", "")

        if status == TaskStatus.SUCCESS.value:
            result_content = _read_json_file(result_path)
            if result_content is not None:
                return R.success({
                    "status": status,
                    "result": result_content,
                    "message": message,
                    "task_id": task_id
                })
            # Marked successful but the result is not readable (yet).
            return R.success({
                "status": TaskStatus.PENDING.value,
                "message": "任务完成,但结果文件未找到",
                "task_id": task_id
            })

        if status == TaskStatus.FAILED.value:
            return R.error(message or "任务失败", code=500)

        # Still in progress.
        return R.success({
            "status": status,
            "message": message,
            "task_id": task_id
        })

    # No status file, but a finished result exists.
    result_content = _read_json_file(result_path)
    if result_content is not None:
        return R.success({
            "status": TaskStatus.SUCCESS.value,
            "result": result_content,
            "task_id": task_id
        })

    # Nothing on disk yet: assume the task is queued.
    return R.success({
        "status": TaskStatus.PENDING.value,
        "message": "任务排队中",
        "task_id": task_id
    })
|
||
|
||
|
||
@router.get("/image_proxy")
async def image_proxy(request: Request, url: str):
    """Fetch the image at ``url`` server-side and relay it to the client.

    The upstream request sends a Bilibili ``Referer`` and the caller's
    ``User-Agent`` (presumably to get past Bilibili hot-link protection --
    TODO confirm). The relayed response may be cached publicly for one day.

    Raises:
        HTTPException(<upstream status>): the upstream answered with a non-200 status.
        HTTPException(500): the upstream request itself failed (timeout,
            connection error, invalid URL, ...).

    NOTE(review): ``url`` is client-controlled and fetched with no host
    allow-list, so this endpoint can be pointed at arbitrary hosts (SSRF risk).
    Consider restricting it to known image CDNs.
    """
    headers = {
        "Referer": "https://www.bilibili.com/",
        "User-Agent": request.headers.get("User-Agent", ""),
    }

    # Only the network call is guarded. The non-200 HTTPException below must
    # propagate with the upstream status instead of being re-wrapped as a 500.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(url, headers=headers)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail="图片获取失败")

    content_type = resp.headers.get("Content-Type", "image/jpeg")
    # A non-streaming client.get() has already loaded the full body, so it can
    # be iterated after the client is closed.
    return StreamingResponse(
        resp.aiter_bytes(),
        media_type=content_type,
        headers={
            "Cache-Control": "public, max-age=86400",  # cache for one day
            "Content-Type": content_type,
        },
    )
|