Files
BiliNote/backend/app/gpt/universal_gpt.py
CyanAutumn d9a7b89e7d 🐞 fix: 增加出错后对已解析段落的缓存功能,重试时不再从头开始
解析长视频时,若附件过大,不再在调用后直接报错,而是将附件分批次发送

在每篇笔记开头默认增加来源地址链接,便于对模糊之处溯源
2026-02-12 18:28:11 +08:00

308 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from app.gpt.base import GPT
from app.gpt.prompt_builder import generate_base_prompt
from app.models.gpt_model import GPTSource
import os
import hashlib
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT
from app.gpt.utils import fix_markdown
from app.gpt.request_chunker import RequestChunker
from app.models.transcriber_model import TranscriptSegment
from datetime import timedelta
from typing import List
class UniversalGPT(GPT):
    """Summarize transcripts through an OpenAI-compatible chat-completions client.

    Long inputs are split by ``RequestChunker`` so every request stays under
    ``max_request_bytes``. Each chunk is summarized separately and the partial
    notes are then merged, possibly over several rounds. When the source has a
    ``checkpoint_key``, progress is written to a JSON checkpoint after every
    successful request so that a retry resumes instead of starting over.
    """

    def __init__(self, client, model: str, temperature: float = 0.7):
        """
        :param client: chat client exposing ``chat.completions.create`` and
            ``models.list`` (presumably an OpenAI SDK client - confirm at the caller).
        :param model: model name sent with every request.
        :param temperature: sampling temperature sent with every request.
        """
        self.client = client
        self.model = model
        self.temperature = temperature
        self.screenshot = False
        self.link = False
        # Byte budget for one JSON-encoded request payload (default 45 MiB).
        self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024)))
        # Checkpoints are written under NOTE_OUTPUT_DIR (default "note_results").
        self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _format_time(self, seconds: float) -> str:
        """Format a transcript offset as ``MM:SS``, or ``H:MM:SS`` from one hour on.

        Slicing ``str(timedelta(...))[2:]`` silently dropped the hour digit
        (3930 s became ``05:30``) and left a stray leading colon from 10 h on.
        Offsets under one hour are formatted exactly as before.
        """
        total = max(0, int(seconds))
        hours, remainder = divmod(total, 3600)
        minutes, secs = divmod(remainder, 60)
        if hours:
            return f"{hours}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    def _build_segment_text(self, segments: List[TranscriptSegment]) -> str:
        """Render segments as one ``<timestamp> - <text>`` line each."""
        return "\n".join(
            f"{self._format_time(seg.start)} - {seg.text.strip()}"
            for seg in segments
        )

    def ensure_segments_type(self, segments) -> List[TranscriptSegment]:
        """Convert dict entries into ``TranscriptSegment``; other entries pass through."""
        return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]

    def create_messages(self, segments: List[TranscriptSegment], **kwargs):
        """Build the chat payload for one chunk of transcript.

        Recognized kwargs: ``title``, ``tags``, ``_format``, ``style``,
        ``extras`` (forwarded to ``generate_base_prompt``) and
        ``video_img_urls`` (appended as ``image_url`` parts; ``None`` is
        treated as no images).

        :return: a one-element list holding a single ``user`` message.
        """
        content_text = generate_base_prompt(
            title=kwargs.get('title'),
            segment_text=self._build_segment_text(segments),
            tags=kwargs.get('tags'),
            _format=kwargs.get('_format'),
            style=kwargs.get('style'),
            extras=kwargs.get('extras'),
        )
        # Assemble the content array; text and image_url parts can be mixed.
        content: List[dict] = [{"type": "text", "text": content_text}]
        for url in kwargs.get('video_img_urls') or []:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": url,
                    "detail": "auto"
                }
            })
        # Everything goes into one message: role + content array.
        return [{
            "role": "user",
            "content": content
        }]

    def list_models(self):
        """Return the client's model listing unchanged."""
        return self.client.models.list()

    def _estimate_messages_bytes(self, messages: list) -> int:
        """Size of ``messages`` as UTF-8 encoded JSON, used as the request-size estimate."""
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _build_merge_messages(self, partials: list) -> list:
        """Build a single user message asking the model to merge ``partials``."""
        merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials)
        return [{
            "role": "user",
            "content": [{"type": "text", "text": merge_text}]
        }]

    def _checkpoint_path(self, checkpoint_key: str) -> Path:
        """Map ``checkpoint_key`` to a file name; characters other than alnum/-/_ become ``_``."""
        safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key)
        return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json"

    def _build_source_signature(self, source: GPTSource) -> str:
        """SHA-256 over every input that affects chunking or model output.

        A checkpoint is only reused when this signature matches, so a change to
        the model, request budget, prompt options or transcript invalidates it.
        """
        payload = {
            "model": self.model,
            "temperature": self.temperature,
            "max_request_bytes": self.max_request_bytes,
            "title": source.title,
            "tags": source.tags,
            "format": source._format,
            "style": source.style,
            "extras": source.extras,
            "video_img_urls": source.video_img_urls or [],
            "segments": [
                {
                    "start": getattr(seg, "start", None),
                    "end": getattr(seg, "end", None),
                    "text": getattr(seg, "text", "")
                }
                for seg in source.segment
            ],
        }
        raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None:
        """Return the stored checkpoint dict, or ``None`` if absent, stale or unreadable.

        Stale (signature mismatch) and unreadable checkpoints are deleted. The
        broad ``except`` is deliberate: a corrupt checkpoint is best-effort data
        and must never block a fresh run.
        """
        path = self._checkpoint_path(checkpoint_key)
        if not path.exists():
            return None
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            if data.get("source_signature") != source_signature:
                path.unlink(missing_ok=True)
                return None
            return data
        except Exception:
            path.unlink(missing_ok=True)
            return None

    def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None:
        """Persist ``partials`` and ``phase`` ("summarize" or "merge").

        Writes to a sibling ``.tmp`` file first and then swaps it in with
        ``Path.replace``, so the checkpoint is never left half-written.
        """
        path = self._checkpoint_path(checkpoint_key)
        data = {
            "version": 1,
            "source_signature": source_signature,
            "phase": phase,
            "partials": partials,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        tmp_path = path.with_suffix(".tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_path.replace(path)

    def _clear_checkpoint(self, checkpoint_key: str) -> None:
        """Delete the checkpoint file if it exists."""
        self._checkpoint_path(checkpoint_key).unlink(missing_ok=True)

    @staticmethod
    def _is_insufficient_quota_error(exc: Exception) -> bool:
        """True when ``exc`` looks like an exhausted-quota / pre-charge failure.

        Matches the relay code ``insufficient_user_quota``, its Chinese message,
        and OpenAI's own ``insufficient_quota`` error code.
        """
        raw = str(exc)
        lowered = raw.lower()
        return (
            "insufficient_user_quota" in raw
            or "预扣费额度失败" in raw
            or "insufficient quota" in lowered
            or "insufficient_quota" in lowered
        )

    @staticmethod
    def _is_retryable_error(exc: Exception) -> bool:
        """True for transient failures (timeouts, rate limits, 5xx, connection errors)."""
        # Quota exhaustion is reported as HTTP 429 but is not transient;
        # retrying it would only add backoff sleeps before the same failure.
        if UniversalGPT._is_insufficient_quota_error(exc):
            return False
        raw = str(exc).lower()
        retryable_tokens = (
            "error code: 524",
            "bad_response_status_code",
            "timed out",
            "timeout",
            "rate limit",
            "error code: 429",
            "error code: 500",
            "error code: 502",
            "error code: 503",
            "error code: 504",
            "apiconnectionerror",
            "connection error",
            "service unavailable",
        )
        if any(token in raw for token in retryable_tokens):
            return True
        status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
        return status in {408, 409, 429, 500, 502, 503, 504, 524}

    def _chat_completion_create(self, messages: list):
        """Call the chat-completions endpoint, retrying transient errors.

        Attempts come from ``OPENAI_RETRY_ATTEMPTS`` (default 3, minimum 1).
        Sleeps ``OPENAI_RETRY_BACKOFF_SECONDS * 2**attempt`` (default base 1.5 s)
        between attempts. Non-retryable errors and the final failure propagate
        unchanged.
        """
        max_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3")))
        base_backoff = float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5"))
        for attempt in range(max_attempts):
            try:
                return self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature
                )
            except Exception as exc:
                if attempt == max_attempts - 1 or not self._is_retryable_error(exc):
                    raise
                time.sleep(base_backoff * (2 ** attempt))
        # Unreachable: the last attempt either returns or re-raises.
        raise RuntimeError("chat completion failed without exception")

    def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str:
        """Merge partial notes into one, in as many budget-sized rounds as needed.

        After every merged group a "merge" checkpoint is saved with the merged
        results so far plus the not-yet-merged inputs. On failure, that same
        progress is saved, so a retry redoes only the failing group onward.

        :raises ValueError: if ``partials`` is empty.
        """
        if not partials:
            raise ValueError("no partial summaries to merge")

        def build_messages(texts, *_args, **_kwargs):
            return self._build_merge_messages(texts)

        def flatten(remaining_groups) -> list:
            # Assumes group_texts_by_budget partitions the input in order.
            return [text for group in remaining_groups for text in group]

        use_checkpoint = bool(checkpoint_key and source_signature)
        merge_chunker = RequestChunker(
            lambda *_args, **_kwargs: [],
            self.max_request_bytes,
            self._estimate_messages_bytes
        )
        current_partials = list(partials)
        while len(current_partials) > 1:
            groups = merge_chunker.group_texts_by_budget(current_partials, build_messages)
            new_partials = []
            for group_idx, group in enumerate(groups):
                try:
                    response = self._chat_completion_create(build_messages(group))
                except Exception:
                    if use_checkpoint:
                        # Keep the groups already merged this round; restarting
                        # from current_partials would discard that paid-for work.
                        self._save_checkpoint(
                            checkpoint_key, source_signature,
                            new_partials + flatten(groups[group_idx:]), "merge",
                        )
                    raise
                new_partials.append(response.choices[0].message.content.strip())
                if use_checkpoint:
                    self._save_checkpoint(
                        checkpoint_key, source_signature,
                        new_partials + flatten(groups[group_idx + 1:]), "merge",
                    )
            current_partials = new_partials
        return current_partials[0]

    def summarize(self, source: GPTSource) -> str:
        """Produce the final note for ``source``, resuming from a checkpoint if one exists.

        A "summarize" checkpoint holds per-chunk summaries, so only the missing
        chunks are sent. A "merge" checkpoint already covers every chunk (some
        entries are merged notes), so merging resumes directly. Re-summarizing
        chunks in that case would duplicate content. The checkpoint is deleted
        once a result is produced.

        :raises ValueError: if there is nothing to summarize.
        """
        self.screenshot = source.screenshot
        self.link = source.link
        source.segment = self.ensure_segments_type(source.segment)
        checkpoint_key = source.checkpoint_key
        source_signature = self._build_source_signature(source) if checkpoint_key else None
        use_checkpoint = bool(checkpoint_key and source_signature)

        saved_partials: list = []
        if use_checkpoint:
            checkpoint = self._load_checkpoint(checkpoint_key, source_signature)
            saved = checkpoint.get("partials") if checkpoint else None
            if isinstance(saved, list) and saved and all(isinstance(p, str) for p in saved):
                if checkpoint.get("phase") == "merge":
                    merged = self._merge_partials(saved, checkpoint_key, source_signature)
                    self._clear_checkpoint(checkpoint_key)
                    return merged
                saved_partials = list(saved)

        prompt_kwargs = {
            "title": source.title,
            "tags": source.tags,
            "_format": source._format,
            "style": source.style,
            "extras": source.extras,
        }

        def message_builder(segments, image_urls, **kwargs):
            return self.create_messages(segments, video_img_urls=image_urls, **kwargs)

        chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes)
        try:
            chunks = chunker.chunk(source.segment, source.video_img_urls or [], **prompt_kwargs)
        except ValueError:
            # Chunking with images failed; fall back to text-only chunks.
            chunks = chunker.chunk(source.segment, [], **prompt_kwargs)

        # More saved summaries than chunks means the checkpoint cannot line up.
        partials = saved_partials if len(saved_partials) <= len(chunks) else []
        for chunk in chunks[len(partials):]:
            messages = self.create_messages(
                chunk.segments,
                video_img_urls=chunk.image_urls,
                **prompt_kwargs,
            )
            try:
                response = self._chat_completion_create(messages)
            except Exception:
                if use_checkpoint:
                    self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
                raise
            partials.append(response.choices[0].message.content.strip())
            if use_checkpoint:
                self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")

        if len(partials) == 1:
            result = partials[0]
        else:
            result = self._merge_partials(partials, checkpoint_key, source_signature)
        if checkpoint_key:
            self._clear_checkpoint(checkpoint_key)
        return result