mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-11 18:10:06 +08:00
DeepSeek deepseek-chat 等非多模态模型只接受 ``content`` 为字符串。旧实现在
没有 ``video_img_urls`` 输入时也把 ``content`` 拼成
``[{"type":"text","text":...}]`` 多模态数组,导致 DeepSeek API 返回
``Failed to deserialize the JSON body into the target type: messages[0]:
unknown variant `image_url`, expected `text```,整个笔记生成流程随之崩溃。
修复方式:``create_messages`` 在没有截图时退回 string content;有截图时维持
原多模态数组形态,多模态模型功能不退化。同时把 ``_build_merge_messages`` 也
改为 string content —— 合并阶段从不带图片,旧的数组形态会让长视频 chunk
之后的合并阶段同样命中 DeepSeek 400。
新增 ``backend/tests/test_universal_gpt_content_format.py`` (6 cases):
- 无图片 / 显式空 image 列表都走 string content
- 有图片仍输出多模态数组(含 ``image_url`` + ``detail: auto``)
- 纯文本响应里完全不含 ``image_url`` 字段
- ``_build_merge_messages`` 用 string content + 仍带入 partials 文本
红基线:在不打补丁的 ``universal_gpt.py`` 上跑这 6 个 case,3 个 string-
content 断言会失败(命中 issue #282 的同一根因),打补丁后 6/6 通过。
Closes #282
315 lines
12 KiB
Python
315 lines
12 KiB
Python
from app.gpt.base import GPT
|
||
from app.gpt.prompt_builder import generate_base_prompt
|
||
from app.models.gpt_model import GPTSource
|
||
import os
|
||
import hashlib
|
||
import json
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT
|
||
from app.gpt.utils import fix_markdown
|
||
from app.gpt.request_chunker import RequestChunker
|
||
from app.models.transcriber_model import TranscriptSegment
|
||
from datetime import timedelta
|
||
from typing import List
|
||
|
||
|
||
class UniversalGPT(GPT):
    """Note generator built on an OpenAI-style chat-completions client.

    The transcript is split into request-sized chunks, each chunk is
    summarised with one chat-completion call, and the partial notes are then
    merged (possibly over several rounds) into a single note.  Progress can be
    checkpointed to disk so an interrupted run resumes instead of starting
    over.
    """

    def __init__(self, client, model: str, temperature: float = 0.7):
        """Store the client/model settings and read tuning knobs from the environment.

        Args:
            client: Object exposing ``chat.completions.create`` and
                ``models.list`` (OpenAI-SDK style).
            model: Model name passed to every completion request.
            temperature: Sampling temperature passed to every completion request.
        """
        self.client = client
        self.model = model
        self.temperature = temperature
        self.screenshot = False
        self.link = False
        self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024)))
        self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Cache the retry configuration once so it is not re-read from the
        # environment on every request.
        self._max_retry_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3")))
        # Clamp to >= 0: time.sleep() raises ValueError on a negative delay.
        self._retry_base_backoff = max(0.0, float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5")))

    def _format_time(self, seconds: float) -> str:
        """Render an offset in seconds as ``MM:SS``.

        Minutes keep counting past 59 (e.g. 3900 s -> ``65:00``).  The previous
        implementation sliced ``str(timedelta)[2:]``, which silently dropped
        the hour field, so every timestamp after the first hour wrapped back
        to ``00:00``.  Output for offsets below one hour is unchanged.

        NOTE(review): consumers of these timestamps are outside this file;
        confirm they accept a minutes field >= 60.
        """
        total_seconds = max(0, int(seconds))
        minutes, secs = divmod(total_seconds, 60)
        return f"{minutes:02d}:{secs:02d}"

    def _build_segment_text(self, segments: List[TranscriptSegment]) -> str:
        """Join segments into ``"<time> - <text>"`` lines, one per segment."""
        return "\n".join(
            f"{self._format_time(seg.start)} - {seg.text.strip()}"
            for seg in segments
        )

    def ensure_segments_type(self, segments) -> List[TranscriptSegment]:
        """Convert any ``dict`` entries into ``TranscriptSegment`` objects."""
        return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]

    def create_messages(self, segments: List[TranscriptSegment], **kwargs):
        """Build the single-user-message payload for one summarisation request.

        Args:
            segments: Transcript segments to include in the prompt.
            **kwargs: ``title``, ``tags``, ``_format``, ``style`` and ``extras``
                are forwarded to ``generate_base_prompt``; ``video_img_urls``
                is an optional list of image URLs.

        Returns:
            A one-element list ``[{"role": "user", "content": ...}]``.
            ``content`` is a plain string when there are no images and a
            multimodal ``text`` + ``image_url`` array otherwise.
        """
        content_text = generate_base_prompt(
            title=kwargs.get('title'),
            segment_text=self._build_segment_text(segments),
            tags=kwargs.get('tags'),
            _format=kwargs.get('_format'),
            style=kwargs.get('style'),
            extras=kwargs.get('extras'),
        )

        video_img_urls = kwargs.get('video_img_urls', [])

        content: list[dict] | str
        if video_img_urls:
            # With screenshots, use the OpenAI multimodal content array
            # (one text part followed by one image_url part per screenshot).
            content = [{"type": "text", "text": content_text}]
            content.extend(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url,
                        "detail": "auto",
                    },
                }
                for url in video_img_urls
            )
        else:
            # Text-only requests fall back to string content: non-multimodal
            # models such as DeepSeek deepseek-chat do not accept the
            # [{"type": "text", ...}] array form and answer with an
            # invalid_request_error (issue #282).  The OpenAI spec itself
            # also allows content to be a string.
            content = content_text

        return [{
            "role": "user",
            "content": content,
        }]

    def list_models(self):
        """Return whatever the client's ``models.list()`` returns."""
        return self.client.models.list()

    def _estimate_messages_bytes(self, messages: list) -> int:
        """Return the UTF-8 byte length of ``messages`` serialised as JSON."""
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _build_merge_messages(self, partials: list) -> list:
        """Build the merge request: ``MERGE_PROMPT`` followed by the partial notes."""
        merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials)
        # The merge phase never carries images, so plain string content is
        # used for compatibility with non-multimodal models (issue #282).
        return [{
            "role": "user",
            "content": merge_text,
        }]

    def _checkpoint_path(self, checkpoint_key: str) -> Path:
        """Map a checkpoint key to its file path inside ``checkpoint_dir``.

        Every character other than alphanumerics, ``-`` and ``_`` is replaced
        with ``_`` so the key is safe to use as a file name.
        """
        safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key)
        return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json"

    def _build_source_signature(self, source: GPTSource) -> str:
        """Return a SHA-256 hex digest over the request inputs and model settings.

        A stored checkpoint is only reused when its signature matches, i.e.
        when none of the fields hashed here have changed.
        """
        payload = {
            "model": self.model,
            "temperature": self.temperature,
            "max_request_bytes": self.max_request_bytes,
            "title": source.title,
            "tags": source.tags,
            "format": source._format,
            "style": source.style,
            "extras": source.extras,
            "video_img_urls": source.video_img_urls or [],
            "segments": [
                {
                    "start": getattr(seg, "start", None),
                    "end": getattr(seg, "end", None),
                    "text": getattr(seg, "text", ""),
                }
                for seg in source.segment
            ],
        }
        raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None:
        """Load the checkpoint for ``checkpoint_key`` if it matches ``source_signature``.

        Returns:
            The parsed checkpoint dict, or ``None`` when the file is missing,
            unreadable, malformed, or was written for a different signature.
            In the last three cases the stale file is deleted.
        """
        path = self._checkpoint_path(checkpoint_key)
        if not path.exists():
            return None
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            if data.get("source_signature") != source_signature:
                path.unlink(missing_ok=True)
                return None
            return data
        except Exception:
            # Deliberately best-effort: any unreadable or malformed checkpoint
            # is discarded rather than aborting the run.
            path.unlink(missing_ok=True)
            return None

    def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None:
        """Persist ``partials`` and the current ``phase`` for later resumption.

        The data is written to a sibling ``.tmp`` file first and then moved
        over the real path, so a reader never sees a half-written file.
        """
        path = self._checkpoint_path(checkpoint_key)
        data = {
            "version": 1,
            "source_signature": source_signature,
            "phase": phase,
            "partials": partials,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        tmp_path = path.with_suffix(".tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_path.replace(path)

    def _clear_checkpoint(self, checkpoint_key: str) -> None:
        """Delete the checkpoint file for ``checkpoint_key`` if it exists."""
        self._checkpoint_path(checkpoint_key).unlink(missing_ok=True)

    @staticmethod
    def _is_insufficient_quota_error(exc: Exception) -> bool:
        """Return True when the error text indicates an exhausted quota."""
        raw = str(exc)
        return (
            "insufficient_user_quota" in raw
            or "预扣费额度失败" in raw
            or "insufficient quota" in raw.lower()
        )

    @staticmethod
    def _is_retryable_error(exc: Exception) -> bool:
        """Return True when ``exc`` looks transient and the request may be retried.

        The lower-cased error text is checked against known markers first;
        failing that, a ``status_code`` / ``status`` attribute on the
        exception is checked against a set of retryable HTTP codes.
        """
        raw = str(exc).lower()
        retryable_tokens = (
            "error code: 524",
            "bad_response_status_code",
            "timed out",
            "timeout",
            "rate limit",
            "error code: 429",
            "error code: 500",
            "error code: 502",
            "error code: 503",
            "error code: 504",
            "apiconnectionerror",
            "connection error",
            "service unavailable",
        )
        if any(token in raw for token in retryable_tokens):
            return True

        status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
        return status in {408, 409, 429, 500, 502, 503, 504, 524}

    def _chat_completion_create(self, messages: list):
        """Send one chat-completion request, retrying transient failures.

        Up to ``_max_retry_attempts`` attempts are made, sleeping
        ``_retry_base_backoff * 2 ** attempt`` seconds between them.

        Raises:
            Exception: The client's own exception, re-raised unchanged when it
                is not retryable or when the last attempt fails.
            RuntimeError: Only if the attempt count is not positive, so that no
                request was made at all.
        """
        attempts = self._max_retry_attempts
        for attempt in range(attempts):
            try:
                return self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature,
                )
            except Exception as exc:
                is_last_attempt = attempt == attempts - 1
                if is_last_attempt or not self._is_retryable_error(exc):
                    raise
                time.sleep(self._retry_base_backoff * (2 ** attempt))

        raise RuntimeError("chat completion failed without exception")

    def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str:
        """Merge partial notes into one, in as many rounds as the byte budget requires.

        Each round groups the current partials by request size, merges every
        group with one completion call, and feeds the results into the next
        round until a single note remains.

        When checkpointing is enabled, the state saved after every success AND
        on failure is "merged results so far + texts not merged yet", so a
        resumed run never repeats a merge that already succeeded.

        Raises:
            Exception: Whatever ``_chat_completion_create`` raises; the
                checkpoint is written before the error propagates.
        """
        def build_messages(texts, *_args, **_kwargs):
            return self._build_merge_messages(texts)

        merge_chunker = RequestChunker(
            lambda *_args, **_kwargs: [],
            self.max_request_bytes,
            self._estimate_messages_bytes,
        )
        can_checkpoint = bool(checkpoint_key and source_signature)

        current_partials = list(partials)
        # NOTE(review): termination relies on group_texts_by_budget returning
        # fewer groups than inputs; its behaviour is not visible from this
        # file - confirm it cannot keep producing one-element groups.
        while len(current_partials) > 1:
            groups = merge_chunker.group_texts_by_budget(current_partials, build_messages)
            new_partials = []
            for group_idx, group in enumerate(groups):
                messages = build_messages(group)
                try:
                    response = self._chat_completion_create(messages)
                except Exception:
                    if can_checkpoint:
                        # Keep what was merged in this round and re-queue the
                        # failed group together with everything after it.
                        pending = [text for rest in groups[group_idx:] for text in rest]
                        self._save_checkpoint(checkpoint_key, source_signature, new_partials + pending, "merge")
                    raise

                new_partials.append(response.choices[0].message.content.strip())

                if can_checkpoint:
                    pending = [text for rest in groups[group_idx + 1:] for text in rest]
                    self._save_checkpoint(checkpoint_key, source_signature, new_partials + pending, "merge")

            current_partials = new_partials

        return current_partials[0]

    def summarize(self, source: GPTSource) -> str:
        """Generate the note for ``source``.

        The transcript is chunked to fit the request byte budget (retrying
        without screenshots if chunking with them raises ``ValueError``), each
        chunk is summarised, and multiple partial notes are merged.  When
        ``source.checkpoint_key`` is set, progress is saved after every
        request and a matching checkpoint is resumed.

        Returns:
            The final note text.

        Raises:
            Exception: Whatever the completion call raises; progress made so
                far is checkpointed first when checkpointing is enabled.
        """
        self.screenshot = source.screenshot
        self.link = source.link
        source.segment = self.ensure_segments_type(source.segment)
        checkpoint_key = source.checkpoint_key
        source_signature = self._build_source_signature(source) if checkpoint_key else None
        can_checkpoint = bool(checkpoint_key and source_signature)

        prompt_kwargs = {
            "title": source.title,
            "tags": source.tags,
            "_format": source._format,
            "style": source.style,
            "extras": source.extras,
        }

        def message_builder(segments, image_urls, **kwargs):
            return self.create_messages(segments, video_img_urls=image_urls, **kwargs)

        chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes)

        try:
            chunks = chunker.chunk(source.segment, source.video_img_urls or [], **prompt_kwargs)
        except ValueError:
            # Chunking with screenshots failed; retry with text only.
            chunks = chunker.chunk(source.segment, [], **prompt_kwargs)

        partials = []
        resume_phase = None
        if can_checkpoint:
            checkpoint = self._load_checkpoint(checkpoint_key, source_signature)
            if checkpoint and isinstance(checkpoint.get("partials"), list):
                partials = checkpoint["partials"]
                resume_phase = checkpoint.get("phase")

        # A "merge"-phase checkpoint already covers every chunk, and its
        # partials no longer map one-to-one onto chunks.  Summarising
        # chunks[len(partials):] on top of it would duplicate content, so in
        # that case go straight to the merge step.
        resuming_merge = resume_phase == "merge" and bool(partials)

        if not resuming_merge:
            if len(partials) > len(chunks):
                partials = []

            for chunk in chunks[len(partials):]:
                messages = self.create_messages(
                    chunk.segments,
                    video_img_urls=chunk.image_urls,
                    **prompt_kwargs,
                )
                try:
                    response = self._chat_completion_create(messages)
                except Exception:
                    if can_checkpoint:
                        self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
                    raise

                partials.append(response.choices[0].message.content.strip())
                if can_checkpoint:
                    self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")

        if len(partials) == 1:
            result = partials[0]
        else:
            result = self._merge_partials(partials, checkpoint_key, source_signature)

        if checkpoint_key:
            self._clear_checkpoint(checkpoint_key)
        return result
|