Files
BiliNote/backend/app/gpt/universal_gpt.py
voidborne-d 3ff7086491 fix(backend): UniversalGPT.create_messages emit string content when no images
DeepSeek deepseek-chat 等非多模态模型只接受 ``content`` 为字符串。旧实现在
没有 ``video_img_urls`` 输入时也把 ``content`` 拼成
``[{"type":"text","text":...}]`` 多模态数组,导致 DeepSeek API 返回
``Failed to deserialize the JSON body into the target type: messages[0]:
unknown variant `image_url`, expected `text```,整个笔记生成流程随之崩溃。

修复方式:``create_messages`` 在没有截图时退回 string content;有截图时维持
原多模态数组形态,多模态模型功能不退化。同时把 ``_build_merge_messages`` 也
改为 string content —— 合并阶段从不带图片,旧的数组形态会让长视频 chunk
之后的合并阶段同样命中 DeepSeek 400。

新增 ``backend/tests/test_universal_gpt_content_format.py`` (6 cases):

- 无图片 / 显式空 image 列表都走 string content
- 有图片仍输出多模态数组(含 ``image_url`` + ``detail: auto``)
- 纯文本响应里完全不含 ``image_url`` 字段
- ``_build_merge_messages`` 用 string content + 仍带入 partials 文本

红基线:在不打补丁的 ``universal_gpt.py`` 上跑这 6 个 case,3 个 string-
content 断言会失败(命中 issue #282 的同一根因),打补丁后 6/6 通过。

Closes #282
2026-05-07 13:50:59 +08:00

315 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from app.gpt.base import GPT
from app.gpt.prompt_builder import generate_base_prompt
from app.models.gpt_model import GPTSource
import os
import hashlib
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT
from app.gpt.utils import fix_markdown
from app.gpt.request_chunker import RequestChunker
from app.models.transcriber_model import TranscriptSegment
from datetime import timedelta
from typing import List
class UniversalGPT(GPT):
    """Chat-completion note generator for any OpenAI-compatible endpoint.

    Responsibilities visible in this class:
    - build the user message from transcript segments (optionally with
      screenshot image URLs, using the multimodal content array),
    - split oversized transcripts into request-sized chunks via
      ``RequestChunker`` and merge the partial summaries back together,
    - retry transient API failures with exponential backoff,
    - persist/resume progress through JSON checkpoint files on disk.
    """

    def __init__(self, client, model: str, temperature: float = 0.7):
        """Store client/model configuration and read tuning knobs from env vars."""
        self.client = client
        self.model = model
        self.temperature = temperature
        self.screenshot = False
        self.link = False
        # Hard cap (bytes) on one serialized request body; larger inputs are
        # split by RequestChunker. Default: 45 MiB.
        self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024)))
        self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Cache retry settings at construction time so each request does not
        # re-read the environment.
        self._max_retry_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3")))
        self._retry_base_backoff = float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5"))

    def _format_time(self, seconds: float) -> str:
        """Render a second offset as a timestamp string.

        ``timedelta`` formats as ``H:MM:SS``; slicing off the first two
        characters yields ``MM:SS`` for durations under one hour.
        NOTE(review): for durations of 10+ hours the slice cuts mid-digit —
        presumably inputs stay well below that; confirm against callers.
        """
        return str(timedelta(seconds=int(seconds)))[2:]

    def _build_segment_text(self, segments: List[TranscriptSegment]) -> str:
        """Join segments into ``MM:SS - text`` lines, one line per segment."""
        return "\n".join(
            f"{self._format_time(seg.start)} - {seg.text.strip()}"
            for seg in segments
        )

    def ensure_segments_type(self, segments) -> List[TranscriptSegment]:
        """Coerce raw dict segments into ``TranscriptSegment``; pass models through."""
        return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]

    def create_messages(self, segments: List[TranscriptSegment], **kwargs):
        """Build the chat ``messages`` payload for one summarization request.

        With screenshots (``video_img_urls``) the content is the OpenAI
        multimodal content array (one text part plus image_url parts);
        without them the content is a plain string so non-multimodal models
        (e.g. DeepSeek deepseek-chat) accept the request — see issue #282.
        """
        content_text = generate_base_prompt(
            title=kwargs.get('title'),
            segment_text=self._build_segment_text(segments),
            tags=kwargs.get('tags'),
            _format=kwargs.get('_format'),
            style=kwargs.get('style'),
            extras=kwargs.get('extras'),
        )
        video_img_urls = kwargs.get('video_img_urls', [])
        content: list[dict] | str
        if video_img_urls:
            # Screenshots present: use the OpenAI multimodal content array
            # (text + image_url entries).
            content = [{"type": "text", "text": content_text}]
            for url in video_img_urls:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": url,
                        "detail": "auto"
                    }
                })
        else:
            # Text-only: fall back to string content. Non-multimodal models
            # (DeepSeek deepseek-chat and friends) reject the
            # [{"type":"text",...}] array form with invalid_request_error
            # (issue #282); the OpenAI spec allows a plain string here anyway.
            content = content_text
        messages = [{
            "role": "user",
            "content": content
        }]
        return messages

    def list_models(self):
        """Return the model list from the backing API client."""
        return self.client.models.list()

    def _estimate_messages_bytes(self, messages: list) -> int:
        """Size of ``messages`` once JSON-serialized, in UTF-8 bytes."""
        # ``json`` is imported at module level; the previous function-local
        # import was redundant.
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _build_merge_messages(self, partials: list) -> list:
        """Build the single merge-request message from partial summaries.

        The merge phase never carries images, so the content stays a plain
        string for compatibility with non-multimodal models (issue #282).
        """
        merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials)
        return [{
            "role": "user",
            "content": merge_text
        }]

    def _checkpoint_path(self, checkpoint_key: str) -> Path:
        """Map a checkpoint key to a filesystem-safe path under ``checkpoint_dir``."""
        safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key)
        return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json"

    def _build_source_signature(self, source: GPTSource) -> str:
        """SHA-256 fingerprint of every input that influences the output.

        A checkpoint is only resumed when this signature matches, so any
        change to the model, options, images or transcript invalidates
        previously saved progress.
        """
        payload = {
            "model": self.model,
            "temperature": self.temperature,
            "max_request_bytes": self.max_request_bytes,
            "title": source.title,
            "tags": source.tags,
            "format": source._format,
            "style": source.style,
            "extras": source.extras,
            "video_img_urls": source.video_img_urls or [],
            "segments": [
                {
                    "start": getattr(seg, "start", None),
                    "end": getattr(seg, "end", None),
                    "text": getattr(seg, "text", "")
                }
                for seg in source.segment
            ],
        }
        raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None:
        """Load a checkpoint, discarding it when stale or unreadable.

        Returns the checkpoint dict, or ``None`` when no usable checkpoint
        exists for this key/signature pair.
        """
        path = self._checkpoint_path(checkpoint_key)
        if not path.exists():
            return None
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            if data.get("source_signature") != source_signature:
                # Inputs changed since the checkpoint was written; drop it.
                path.unlink(missing_ok=True)
                return None
            return data
        except Exception:
            # Corrupt or partially written file: delete and start over.
            path.unlink(missing_ok=True)
            return None

    def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None:
        """Atomically persist progress: write a temp file, then rename over."""
        path = self._checkpoint_path(checkpoint_key)
        data = {
            "version": 1,
            "source_signature": source_signature,
            "phase": phase,
            "partials": partials,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        tmp_path = path.with_suffix(".tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_path.replace(path)

    def _clear_checkpoint(self, checkpoint_key: str) -> None:
        """Remove the checkpoint file for ``checkpoint_key`` if it exists."""
        self._checkpoint_path(checkpoint_key).unlink(missing_ok=True)

    @staticmethod
    def _is_insufficient_quota_error(exc: Exception) -> bool:
        """True when the exception text indicates an exhausted account quota."""
        raw = str(exc)
        return (
            "insufficient_user_quota" in raw
            or "预扣费额度失败" in raw
            or "insufficient quota" in raw.lower()
        )

    @staticmethod
    def _is_retryable_error(exc: Exception) -> bool:
        """Heuristic: is this API failure transient enough to retry?

        Checks well-known substrings in the error text first, then falls
        back to the exception's HTTP status attribute when present.
        """
        raw = str(exc).lower()
        retryable_tokens = (
            "error code: 524",
            "bad_response_status_code",
            "timed out",
            "timeout",
            "rate limit",
            "error code: 429",
            "error code: 500",
            "error code: 502",
            "error code: 503",
            "error code: 504",
            "apiconnectionerror",
            "connection error",
            "service unavailable",
        )
        if any(token in raw for token in retryable_tokens):
            return True
        status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
        return status in {408, 409, 429, 500, 502, 503, 504, 524}

    def _chat_completion_create(self, messages: list):
        """Call the chat-completions API with exponential-backoff retries.

        Non-retryable errors and the final failed attempt re-raise as-is.
        """
        last_exc = None
        for attempt in range(self._max_retry_attempts):
            try:
                return self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature
                )
            except Exception as exc:
                last_exc = exc
                if attempt == self._max_retry_attempts - 1 or not self._is_retryable_error(exc):
                    raise
                sleep_seconds = self._retry_base_backoff * (2 ** attempt)
                time.sleep(sleep_seconds)
        # Defensive tail: unreachable while _max_retry_attempts >= 1 (the loop
        # either returns or re-raises on the last attempt); kept as a safety net.
        if last_exc is not None:
            raise last_exc
        raise RuntimeError("chat completion failed without exception")

    def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str:
        """Iteratively merge partial summaries into one note.

        Each round groups partials under the request-size budget, asks the
        model to merge each group, and repeats until a single text remains.
        When a checkpoint key is given, resumable state (phase ``"merge"``)
        is saved after every group and before re-raising on failure.
        """
        if not partials:
            # Nothing to merge (empty transcript upstream); previously this
            # fell through to an obscure IndexError.
            return ""

        def build_messages(texts, *_args, **_kwargs):
            return self._build_merge_messages(texts)

        merge_chunker = RequestChunker(
            lambda *_args, **_kwargs: [],
            self.max_request_bytes,
            self._estimate_messages_bytes
        )
        current_partials = list(partials)
        while len(current_partials) > 1:
            groups = merge_chunker.group_texts_by_budget(current_partials, build_messages)
            new_partials = []
            for group_idx, group in enumerate(groups):
                messages = build_messages(group)
                try:
                    response = self._chat_completion_create(messages)
                except Exception:
                    # Persist what we have so a rerun can resume the merge.
                    if checkpoint_key and source_signature:
                        self._save_checkpoint(checkpoint_key, source_signature, current_partials, "merge")
                    raise
                new_partials.append(response.choices[0].message.content.strip())
                if checkpoint_key and source_signature:
                    # Resumable state = merged groups so far + untouched groups.
                    remaining_partials = []
                    for remaining_group in groups[group_idx + 1:]:
                        remaining_partials.extend(remaining_group)
                    resumable_partials = new_partials + remaining_partials
                    self._save_checkpoint(checkpoint_key, source_signature, resumable_partials, "merge")
            current_partials = new_partials
        return current_partials[0]

    def summarize(self, source: GPTSource) -> str:
        """Produce the note text for ``source``.

        Splits the transcript into request-sized chunks, summarizes each
        chunk, checkpoints progress under ``source.checkpoint_key`` (when
        set) so a crashed run can resume, then merges multi-chunk results
        into one note. The checkpoint is cleared on success.
        """
        self.screenshot = source.screenshot
        self.link = source.link
        source.segment = self.ensure_segments_type(source.segment)
        checkpoint_key = source.checkpoint_key
        source_signature = self._build_source_signature(source) if checkpoint_key else None

        def message_builder(segments, image_urls, **kwargs):
            return self.create_messages(segments, video_img_urls=image_urls, **kwargs)

        chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes)
        try:
            chunks = chunker.chunk(
                source.segment,
                source.video_img_urls or [],
                title=source.title,
                tags=source.tags,
                _format=source._format,
                style=source.style,
                extras=source.extras
            )
        except ValueError:
            # Images alone can blow the byte budget; retry chunking text-only.
            chunks = chunker.chunk(
                source.segment,
                [],
                title=source.title,
                tags=source.tags,
                _format=source._format,
                style=source.style,
                extras=source.extras
            )
        partials = []
        if checkpoint_key and source_signature:
            checkpoint = self._load_checkpoint(checkpoint_key, source_signature)
            if checkpoint and isinstance(checkpoint.get("partials"), list):
                partials = checkpoint["partials"]
                if len(partials) > len(chunks):
                    # Checkpoint no longer matches the chunking; start over.
                    partials = []
        # Resume after the chunks already covered by the checkpoint.
        for chunk in chunks[len(partials):]:
            messages = self.create_messages(
                chunk.segments,
                title=source.title,
                tags=source.tags,
                video_img_urls=chunk.image_urls,
                _format=source._format,
                style=source.style,
                extras=source.extras
            )
            try:
                response = self._chat_completion_create(messages)
            except Exception:
                if checkpoint_key and source_signature:
                    self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
                raise
            partials.append(response.choices[0].message.content.strip())
            if checkpoint_key and source_signature:
                self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
        if len(partials) == 1:
            if checkpoint_key:
                self._clear_checkpoint(checkpoint_key)
            return partials[0]
        merged = self._merge_partials(partials, checkpoint_key, source_signature)
        if checkpoint_key:
            self._clear_checkpoint(checkpoint_key)
        return merged