🐞 fix: cache the segments that were already processed when an error occurs, so a retry no longer starts over from the beginning

When summarizing long videos, an oversized attachment no longer makes the call fail; the attachment is now split and sent in batches

Each note now begins with the source URL by default, so unclear passages can be traced back to the original video
CyanAutumn
2026-02-12 18:28:11 +08:00
parent 7b45db2f59
commit d9a7b89e7d
67 changed files with 279293 additions and 64 deletions
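To illustrate the batching idea behind the attachment change described above: requests are packed greedily under a byte budget measured on the serialized payload. A minimal sketch, assuming plain text inputs (the real logic lives in RequestChunker below, which additionally splits a single oversized segment):

```python
import json

def pack_by_budget(texts, max_bytes):
    """Greedy packing sketch using the same byte measure as RequestChunker.estimate."""
    batches, current = [], []
    for text in texts:
        candidate = current + [text]
        size = len(json.dumps(candidate, ensure_ascii=False).encode("utf-8"))
        if size <= max_bytes:
            current = candidate
        else:
            if current:
                batches.append(current)
            current = [text]  # NOTE: unlike RequestChunker, an oversized text is not split here
    if current:
        batches.append(current)
    return batches
```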

View File

@@ -18,12 +18,12 @@ BASE_PROMPT = '''
- **不要**将输出包裹在代码块中(例如:```` ```markdown ```````` ``` ````)。
请注意,在生成 Markdown 时避免将编号标题(如“1. **内容**”)写成有序列表的格式,以免解析错误。
- 如果要加粗并保留编号,应使用 `1\. **内容**`(加反斜杠),防止被误解析为有序列表。
- 如果要加粗并保留编号,应使用 `1\\. **内容**`(加反斜杠),防止被误解析为有序列表。
- 或者使用 `## 1. 内容` 的形式作为标题。
请确保以下格式 **不会出现误渲染**
`1. **xxx**`
`1\. **xxx**` 或 `## 1. xxx`
`1\\. **xxx**` 或 `## 1. xxx`
视频分段(格式:开始时间 - 内容):
@@ -66,4 +66,13 @@ SCREENSHOT='''
8. **Screenshot placeholders**: If a section involves **visual demonstrations, code walkthroughs, UI interactions**, or any content where visuals aid understanding, insert a screenshot cue at the end of that section:
- Format: `*Screenshot-[mm:ss]`
- Only use it when truly helpful.
'''
'''
MERGE_PROMPT = '''
你将收到多个来自同一视频的 Markdown 笔记片段,请合并成一份完整笔记:
- 只做合并与去重,不要发明新内容
- 保持原有标题层级与 Markdown 结构
- 保留所有 *Content-[mm:ss] 与 *Screenshot-[mm:ss] 标记
- 保持中文输出,专有名词保留英文
- 不要使用代码块包裹输出
'''

View File

@@ -0,0 +1,161 @@
from dataclasses import dataclass
from typing import Callable, List, Optional
@dataclass
class ChunkPayload:
segments: list
image_urls: list
class RequestChunker:
def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None):
self.message_builder = message_builder
self.max_bytes = max_bytes
self.size_estimator = size_estimator
def estimate(self, messages) -> int:
if self.size_estimator:
return self.size_estimator(messages)
import json
return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))
def _messages_size(self, segments, image_urls, **kwargs) -> int:
messages = self.message_builder(segments, image_urls, **kwargs)
return self.estimate(messages)
def _get_text(self, segment) -> str:
if isinstance(segment, dict):
return segment.get("text", "")
return getattr(segment, "text", "")
def _make_segment(self, segment, text: str):
if isinstance(segment, dict):
new_seg = dict(segment)
new_seg["text"] = text
return new_seg
if hasattr(segment, "__dict__"):
data = dict(segment.__dict__)
data["text"] = text
return type(segment)(**data)
return type(segment)(segment.start, segment.end, text)
def _split_segment_to_fit(self, segment, **kwargs):
text = self._get_text(segment)
if not text:
raise ValueError("empty segment cannot be split")
lo, hi = 1, len(text)
best = None
while lo <= hi:
mid = (lo + hi) // 2
candidate = self._make_segment(segment, text[:mid])
size = self._messages_size([candidate], [], **kwargs)
if size <= self.max_bytes:
best = mid
lo = mid + 1
else:
hi = mid - 1
if best is None:
raise ValueError("single segment too large to fit request")
head = self._make_segment(segment, text[:best])
tail = self._make_segment(segment, text[best:])
return head, tail
def chunk(self, segments: list, image_urls: list, **kwargs) -> List[ChunkPayload]:
segments = list(segments or [])
image_urls = list(image_urls or [])
if not segments and not image_urls:
return []
chunks: List[ChunkPayload] = []
seg_idx = 0
while seg_idx < len(segments):
batch_segments = []
while seg_idx < len(segments):
candidate = batch_segments + [segments[seg_idx]]
size = self._messages_size(candidate, [], **kwargs)
if size <= self.max_bytes:
batch_segments = candidate
seg_idx += 1
continue
if not batch_segments:
head, tail = self._split_segment_to_fit(segments[seg_idx], **kwargs)
segments[seg_idx] = head
segments.insert(seg_idx + 1, tail)
continue
break
if not batch_segments:
raise ValueError("unable to fit any content into chunk")
chunks.append(ChunkPayload(segments=batch_segments, image_urls=[]))
if not image_urls:
return chunks
if not chunks:
chunks = [ChunkPayload(segments=[], image_urls=[])]
if not segments:
for image in image_urls:
appended = False
for chunk in chunks[-1:]:
candidate_images = chunk.image_urls + [image]
if self._messages_size(chunk.segments, candidate_images, **kwargs) <= self.max_bytes:
chunk.image_urls = candidate_images
appended = True
break
if appended:
continue
if self._messages_size([], [image], **kwargs) > self.max_bytes:
raise ValueError("single image payload exceeds max_bytes")
chunks.append(ChunkPayload(segments=[], image_urls=[image]))
return chunks
chunk_count = len(chunks)
total_images = len(image_urls)
for idx, image in enumerate(image_urls):
preferred_idx = min(chunk_count - 1, (idx * chunk_count) // total_images)
placed = False
for chunk_idx in range(preferred_idx, len(chunks)):
chunk = chunks[chunk_idx]
candidate_images = chunk.image_urls + [image]
if self._messages_size(chunk.segments, candidate_images, **kwargs) <= self.max_bytes:
chunk.image_urls = candidate_images
placed = True
break
if placed:
continue
if self._messages_size([], [image], **kwargs) > self.max_bytes:
raise ValueError("single image payload exceeds max_bytes")
chunks.append(ChunkPayload(segments=[], image_urls=[image]))
return chunks
def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]:
groups: List[List[str]] = []
idx = 0
while idx < len(texts):
group: List[str] = []
while idx < len(texts):
candidate = group + [texts[idx]]
try:
messages = build_messages(candidate, [], **kwargs)
except TypeError:
messages = build_messages(candidate, **kwargs)
size = self.estimate(messages)
if size <= self.max_bytes:
group = candidate
idx += 1
continue
if not group:
raise ValueError("single text block exceeds max_bytes")
break
groups.append(group)
return groups
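A short usage sketch for the module above; `build_messages` and the dict-style segment are hypothetical stand-ins, the real wiring is in UniversalGPT.summarize further down:

```python
from app.gpt.request_chunker import RequestChunker

def build_messages(segments, image_urls, **kwargs):
    # hypothetical builder: one user message with text plus optional images
    content = [{"type": "text", "text": " ".join(s["text"] for s in segments)}]
    content += [{"type": "image_url", "image_url": {"url": u}} for u in image_urls]
    return [{"role": "user", "content": content}]

chunker = RequestChunker(build_messages, max_bytes=45 * 1024 * 1024)
chunks = chunker.chunk([{"start": 0, "end": 5, "text": "hello world"}], [])
for chunk in chunks:
    messages = build_messages(chunk.segments, chunk.image_urls)
    # one request per chunk would be sent to the model here
```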

View File

@@ -1,8 +1,16 @@
from app.gpt.base import GPT
from app.gpt.prompt_builder import generate_base_prompt
from app.models.gpt_model import GPTSource
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
import os
import hashlib
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT
from app.gpt.utils import fix_markdown
from app.gpt.request_chunker import RequestChunker
from app.models.transcriber_model import TranscriptSegment
from datetime import timedelta
from typing import List
@@ -15,6 +23,9 @@ class UniversalGPT(GPT):
self.temperature = temperature
self.screenshot = False
self.link = False
self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024)))
self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
def _format_time(self, seconds: float) -> str:
return str(timedelta(seconds=int(seconds)))[2:]
@@ -40,7 +51,7 @@ class UniversalGPT(GPT):
)
# ⛳ 组装 content 数组,支持 text + image_url 混合
content = [{"type": "text", "text": content_text}]
content: List[dict] = [{"type": "text", "text": content_text}]
video_img_urls = kwargs.get('video_img_urls', [])
for url in video_img_urls:
@@ -63,23 +74,234 @@ class UniversalGPT(GPT):
def list_models(self):
return self.client.models.list()
def _estimate_messages_bytes(self, messages: list) -> int:
return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))
def _build_merge_messages(self, partials: list) -> list:
merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials)
return [{
"role": "user",
"content": [{"type": "text", "text": merge_text}]
}]
def _checkpoint_path(self, checkpoint_key: str) -> Path:
safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key)
return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json"
def _build_source_signature(self, source: GPTSource) -> str:
payload = {
"model": self.model,
"temperature": self.temperature,
"max_request_bytes": self.max_request_bytes,
"title": source.title,
"tags": source.tags,
"format": source._format,
"style": source.style,
"extras": source.extras,
"video_img_urls": source.video_img_urls or [],
"segments": [
{
"start": getattr(seg, "start", None),
"end": getattr(seg, "end", None),
"text": getattr(seg, "text", "")
}
for seg in source.segment
],
}
raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None:
path = self._checkpoint_path(checkpoint_key)
if not path.exists():
return None
try:
data = json.loads(path.read_text(encoding="utf-8"))
if data.get("source_signature") != source_signature:
path.unlink(missing_ok=True)
return None
return data
except Exception:
path.unlink(missing_ok=True)
return None
def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None:
path = self._checkpoint_path(checkpoint_key)
data = {
"version": 1,
"source_signature": source_signature,
"phase": phase,
"partials": partials,
"updated_at": datetime.now(timezone.utc).isoformat(),
}
tmp_path = path.with_suffix(".tmp")
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
tmp_path.replace(path)
def _clear_checkpoint(self, checkpoint_key: str) -> None:
self._checkpoint_path(checkpoint_key).unlink(missing_ok=True)
@staticmethod
def _is_insufficient_quota_error(exc: Exception) -> bool:
raw = str(exc)
return (
"insufficient_user_quota" in raw
or "预扣费额度失败" in raw
or "insufficient quota" in raw.lower()
)
@staticmethod
def _is_retryable_error(exc: Exception) -> bool:
raw = str(exc).lower()
retryable_tokens = (
"error code: 524",
"bad_response_status_code",
"timed out",
"timeout",
"rate limit",
"error code: 429",
"error code: 500",
"error code: 502",
"error code: 503",
"error code: 504",
"apiconnectionerror",
"connection error",
"service unavailable",
)
if any(token in raw for token in retryable_tokens):
return True
status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
return status in {408, 409, 429, 500, 502, 503, 504, 524}
def _chat_completion_create(self, messages: list):
max_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3")))
base_backoff = float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5"))
last_exc = None
for attempt in range(max_attempts):
try:
return self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature
)
except Exception as exc:
last_exc = exc
if attempt == max_attempts - 1 or not self._is_retryable_error(exc):
raise
sleep_seconds = base_backoff * (2 ** attempt)
time.sleep(sleep_seconds)
if last_exc is not None:
raise last_exc
raise RuntimeError("chat completion failed without exception")
def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str:
def build_messages(texts, *_args, **_kwargs):
return self._build_merge_messages(texts)
merge_chunker = RequestChunker(
lambda *_args, **_kwargs: [],
self.max_request_bytes,
self._estimate_messages_bytes
)
current_partials = list(partials)
while len(current_partials) > 1:
groups = merge_chunker.group_texts_by_budget(current_partials, build_messages)
new_partials = []
for group_idx, group in enumerate(groups):
messages = build_messages(group)
try:
response = self._chat_completion_create(messages)
except Exception as exc:
if checkpoint_key and source_signature:
self._save_checkpoint(checkpoint_key, source_signature, current_partials, "merge")
raise
new_partials.append(response.choices[0].message.content.strip())
if checkpoint_key and source_signature:
remaining_partials = []
for remaining_group in groups[group_idx + 1:]:
remaining_partials.extend(remaining_group)
resumable_partials = new_partials + remaining_partials
self._save_checkpoint(checkpoint_key, source_signature, resumable_partials, "merge")
current_partials = new_partials
return current_partials[0]
def summarize(self, source: GPTSource) -> str:
self.screenshot = source.screenshot
self.link = source.link
source.segment = self.ensure_segments_type(source.segment)
checkpoint_key = source.checkpoint_key
source_signature = self._build_source_signature(source) if checkpoint_key else None
messages = self.create_messages(
source.segment,
title=source.title,
tags=source.tags,
video_img_urls=source.video_img_urls,
_format=source._format,
style=source.style,
extras=source.extras
)
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.7
)
return response.choices[0].message.content.strip()
def message_builder(segments, image_urls, **kwargs):
return self.create_messages(segments, video_img_urls=image_urls, **kwargs)
chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes)
try:
chunks = chunker.chunk(
source.segment,
source.video_img_urls or [],
title=source.title,
tags=source.tags,
_format=source._format,
style=source.style,
extras=source.extras
)
except ValueError:
chunks = chunker.chunk(
source.segment,
[],
title=source.title,
tags=source.tags,
_format=source._format,
style=source.style,
extras=source.extras
)
partials = []
if checkpoint_key and source_signature:
checkpoint = self._load_checkpoint(checkpoint_key, source_signature)
if checkpoint and isinstance(checkpoint.get("partials"), list):
partials = checkpoint["partials"]
if len(partials) > len(chunks):
partials = []
for chunk in chunks[len(partials):]:
messages = self.create_messages(
chunk.segments,
title=source.title,
tags=source.tags,
video_img_urls=chunk.image_urls,
_format=source._format,
style=source.style,
extras=source.extras
)
try:
response = self._chat_completion_create(messages)
except Exception as exc:
if checkpoint_key and source_signature:
self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
raise
partials.append(response.choices[0].message.content.strip())
if checkpoint_key and source_signature:
self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
if len(partials) == 1:
if checkpoint_key:
self._clear_checkpoint(checkpoint_key)
return partials[0]
merged = self._merge_partials(partials, checkpoint_key, source_signature)
if checkpoint_key:
self._clear_checkpoint(checkpoint_key)
return merged
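For reference, a checkpoint file written by `_save_checkpoint` above has roughly this shape (all values below are placeholders):

```python
# Placeholder example of <task_id>.gpt.checkpoint.json written by _save_checkpoint
checkpoint = {
    "version": 1,
    "source_signature": "<sha256 over model, params and segments>",
    "phase": "summarize",  # or "merge"
    "partials": ["## Notes for chunk 1 ...", "## Notes for chunk 2 ..."],
    "updated_at": "2026-02-12T10:28:11+00:00",
}
```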

View File

@@ -15,4 +15,5 @@ class GPTSource:
extras: Optional[str] = None
_format: Optional[list] = None
video_img_urls: Optional[list] = None
checkpoint_key: Optional[str] = None

View File

@@ -15,6 +15,7 @@ from app.enmus.exception import NoteErrorEnum
from app.enmus.note_enums import DownloadQuality
from app.exceptions.note import NoteError
from app.services.note import NoteGenerator, logger
from app.services.task_serial_executor import task_serial_executor
from app.utils.response import ResponseWrapper as R
from app.utils.url_parser import extract_video_id
from app.validators.video_url_validator import is_supported_video_url
@@ -82,22 +83,26 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
, video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size
)
def _execute_note_task():
return NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot,
video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size,
)
logger.info(f"任务进入串行队列,等待执行 (task_id={task_id})")
note = task_serial_executor.run(_execute_note_task)
logger.info(f"Note generated: {task_id}")
if not note or not note.markdown:
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
@@ -144,13 +149,14 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
if data.task_id:
# 如果传了task_id说明是重试
task_id = data.task_id
# 更新之前的状态
NoteGenerator()._update_status(task_id, TaskStatus.PENDING)
logger.info(f"重试模式,复用已有 task_id={task_id}")
else:
# 正常新建任务
task_id = str(uuid.uuid4())
# 统一先写入 PENDING,表示已进入队列,等待串行执行
NoteGenerator()._update_status(task_id, TaskStatus.PENDING)
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality, data.link,
data.screenshot, data.model_name, data.provider_id, data.format, data.style,
data.extras, data.video_understanding, data.video_interval, data.grid_size)

View File

@@ -1,7 +1,6 @@
import json
import logging
import os
import re
from dataclasses import asdict
from pathlib import Path
from typing import List, Optional, Tuple, Union, Any
@@ -32,7 +31,8 @@ from app.services.constant import SUPPORT_PLATFORM_MAP
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
from app.utils.note_helper import replace_content_markers
from app.utils.note_helper import replace_content_markers, prepend_source_link
from app.utils.screenshot_marker import extract_screenshot_timestamps
from app.utils.status_code import StatusCode
from app.utils.video_helper import generate_screenshot
from app.utils.video_reader import VideoReader
@@ -182,6 +182,8 @@ class NoteGenerator:
platform=platform,
)
markdown = prepend_source_link(markdown, str(video_url))
# 5. 保存记录到数据库
self._update_status(task_id, TaskStatus.SAVING)
self._save_metadata(video_id=audio_meta.video_id, platform=platform, task_id=task_id)
@@ -353,6 +355,10 @@ class NoteGenerator:
# 判断是否需要下载视频
need_video = screenshot or video_understanding
if screenshot and not grid_size:
grid_size = [2, 2]
frame_interval = video_interval if video_interval and video_interval > 0 else 6
if need_video:
try:
logger.info("开始下载视频")
@@ -365,10 +371,10 @@ class NoteGenerator:
self.video_img_urls=VideoReader(
video_path=str(self.video_path),
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280,
unit_height=720,
save_quality=90,
frame_interval=frame_interval,
unit_width=960,
unit_height=540,
save_quality=80,
).run()
else:
logger.info("未指定 grid_size跳过缩略图生成")
@@ -540,6 +546,7 @@ class NoteGenerator:
_format=formats,
style=style,
extras=extras,
checkpoint_key=task_id,
)
try:
@@ -592,7 +599,7 @@ class NoteGenerator:
:param video_path: 本地视频文件路径
:return: 替换后的 Markdown 字符串
"""
matches: List[Tuple[str, int]] = self._extract_screenshot_timestamps(markdown)
matches: List[Tuple[str, int]] = extract_screenshot_timestamps(markdown)
for idx, (marker, ts) in enumerate(matches):
try:
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
@@ -615,14 +622,7 @@ class NoteGenerator:
:param markdown: 原始 Markdown 文本
:return: 标记与对应时间戳秒数的列表
"""
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
results: List[Tuple[str, int]] = []
for match in re.finditer(pattern, markdown):
mm = match.group(1) or match.group(3)
ss = match.group(2) or match.group(4)
total_seconds = int(mm) * 60 + int(ss)
results.append((match.group(0), total_seconds))
return results
return extract_screenshot_timestamps(markdown)
def _save_metadata(self, video_id: str, platform: str, task_id: str) -> None:
"""
@@ -636,4 +636,4 @@ class NoteGenerator:
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
logger.info(f"已保存任务记录到数据库 (video_id={video_id}, platform={platform}, task_id={task_id})")
except Exception as e:
logger.error(f"保存任务记录失败:{e}")
logger.error(f"保存任务记录失败:{e}")

View File

@@ -0,0 +1,14 @@
import threading
from typing import Any, Callable
class SerialTaskExecutor:
def __init__(self):
self._lock = threading.Lock()
def run(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
with self._lock:
return fn(*args, **kwargs)
task_serial_executor = SerialTaskExecutor()
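A minimal usage sketch, assuming two hypothetical note-generation callables; both run strictly one after the other because they share the module-level lock:

```python
from app.services.task_serial_executor import task_serial_executor

# Hypothetical callables; the shared lock guarantees they never overlap.
first = task_serial_executor.run(lambda: "generate note for task A")
second = task_serial_executor.run(lambda: "generate note for task B")
```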

View File

@@ -5,6 +5,37 @@ import re
import re
def prepend_source_link(markdown: str | None, source_url: str) -> str | None:
"""
在笔记开头添加来源链接;若首个非空行已包含来源链接,则更新该行并避免重复。
"""
if markdown is None:
return None
source = (source_url or "").strip()
if not source:
return markdown
header = f"> 来源链接:{source}"
lines = markdown.splitlines()
first_non_empty_idx = None
for idx, line in enumerate(lines):
if line.strip():
first_non_empty_idx = idx
break
if first_non_empty_idx is not None:
first_line = lines[first_non_empty_idx].strip()
if first_line.startswith("> 来源链接:") or first_line.startswith("来源链接:"):
lines[first_non_empty_idx] = header
return "\n".join(lines)
if markdown.strip():
return f"{header}\n\n{markdown}"
return header
def replace_content_markers(markdown: str, video_id: str, platform: str = 'bilibili') -> str:
"""
替换 *Content-04:16*、Content-04:16 或 Content-[04:16] 为超链接,跳转到对应平台视频的时间位置
@@ -12,18 +43,20 @@ def replace_content_markers(markdown: str, video_id: str, platform: str = 'bilib
# 匹配三种形式:*Content-04:16*、Content-04:16、Content-[04:16]
pattern = r"(?:\*?)Content-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2}))"
safe_video_id = video_id
def replacer(match):
mm = match.group(1) or match.group(3)
ss = match.group(2) or match.group(4)
total_seconds = int(mm) * 60 + int(ss)
if platform == 'bilibili':
video_id = video_id.replace("_p", "?p=")
url = f"https://www.bilibili.com/video/{video_id}&t={total_seconds}"
parsed_video_id = safe_video_id.replace("_p", "?p=")
url = f"https://www.bilibili.com/video/{parsed_video_id}&t={total_seconds}"
elif platform == 'youtube':
url = f"https://www.youtube.com/watch?v={video_id}&t={total_seconds}s"
url = f"https://www.youtube.com/watch?v={safe_video_id}&t={total_seconds}s"
elif platform == 'douyin':
url = f"https://www.douyin.com/video/{video_id}"
url = f"https://www.douyin.com/video/{safe_video_id}"
return f"[原片 @ {mm}:{ss}]({url})"
else:
return f"({mm}:{ss})"

View File

@@ -0,0 +1,13 @@
import re
from typing import List, Tuple
def extract_screenshot_timestamps(markdown: str) -> List[Tuple[str, int]]:
pattern = r"(\*?Screenshot-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2})))"
results: List[Tuple[str, int]] = []
for match in re.finditer(pattern, markdown):
mm = match.group(2) or match.group(4)
ss = match.group(3) or match.group(5)
total_seconds = int(mm) * 60 + int(ss)
results.append((match.group(1), total_seconds))
return results

View File

@@ -1,4 +1,5 @@
import base64
import hashlib
import os
import re
import subprocess
@@ -14,6 +15,7 @@ class VideoReader:
video_path: str,
grid_size=(3, 3),
frame_interval=2,
dedupe_enabled=True,
unit_width=960,
unit_height=540,
save_quality=90,
@@ -23,6 +25,7 @@ class VideoReader:
self.video_path = video_path
self.grid_size = grid_size
self.frame_interval = frame_interval
self.dedupe_enabled = dedupe_enabled
self.unit_width = unit_width
self.unit_height = unit_height
self.save_quality = save_quality
@@ -31,6 +34,14 @@ class VideoReader:
print(f"视频路径:{video_path}",self.frame_dir,self.grid_dir)
self.font_path = font_path
@staticmethod
def _calculate_file_md5(file_path: str) -> str:
hasher = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hasher.update(chunk)
return hasher.hexdigest()
def format_time(self, seconds: float) -> str:
mm = int(seconds // 60)
ss = int(seconds % 60)
@@ -51,12 +62,21 @@ class VideoReader:
timestamps = [i for i in range(0, int(duration), self.frame_interval)][:max_frames]
image_paths = []
last_hash = None
for ts in timestamps:
time_label = self.format_time(ts)
output_path = os.path.join(self.frame_dir, f"frame_{time_label}.jpg")
cmd = ["ffmpeg", "-ss", str(ts), "-i", self.video_path, "-frames:v", "1", "-q:v", "2", "-y", output_path,
"-hide_banner", "-loglevel", "error"]
subprocess.run(cmd, check=True)
if self.dedupe_enabled:
frame_hash = self._calculate_file_md5(output_path)
if frame_hash == last_hash:
os.remove(output_path)
continue
last_hash = frame_hash
image_paths.append(output_path)
return image_paths
except Exception as e:

backend/run.bat Normal file
View File

@@ -0,0 +1 @@
python main.py

View File

@@ -0,0 +1,35 @@
import importlib.util
import pathlib
import unittest
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "utils" / "note_helper.py"
spec = importlib.util.spec_from_file_location("note_helper", MODULE_PATH)
if spec is None or spec.loader is None:
raise ImportError("note_helper module spec not found")
note_helper = importlib.util.module_from_spec(spec)
spec.loader.exec_module(note_helper)
class TestNoteHelper(unittest.TestCase):
def test_prepend_source_link_adds_header_at_top(self):
source_url = "https://www.bilibili.com/video/BV1xx411c7mD"
markdown = "## 标题\n\n内容"
result = note_helper.prepend_source_link(markdown, source_url)
self.assertTrue(result.startswith(f"> 来源链接:{source_url}\n\n"))
self.assertIn("## 标题", result)
def test_prepend_source_link_does_not_duplicate_when_header_exists(self):
source_url = "https://www.youtube.com/watch?v=abc123"
markdown = f"> 来源链接:{source_url}\n\n## 标题\n\n内容"
result = note_helper.prepend_source_link(markdown, source_url)
self.assertEqual(result, markdown)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,97 @@
import importlib.util
import pathlib
import unittest
from dataclasses import dataclass
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py"
spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH)
if spec is None or spec.loader is None:
raise ImportError("request_chunker module spec not found")
request_chunker = importlib.util.module_from_spec(spec)
spec.loader.exec_module(request_chunker)
RequestChunker = request_chunker.RequestChunker
@dataclass
class DummySeg:
start: float
end: float
text: str
def build_messages(segments, image_urls, **_):
content = [{"type": "text", "text": "".join(s.text for s in segments)}]
for url in image_urls:
content.append({"type": "image_url", "image_url": {"url": url, "detail": "auto"}})
return [{"role": "user", "content": content}]
def size_estimator(messages):
size = 0
for part in messages[0]["content"]:
if part["type"] == "text":
size += len(part["text"])
else:
size += len(part["image_url"]["url"])
return size
class TestRequestChunker(unittest.TestCase):
def test_chunk_segments_preserves_order_and_content(self):
segments = [
DummySeg(0, 1, "aaaa"),
DummySeg(1, 2, "bbbb"),
DummySeg(2, 3, "cccc"),
]
chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator)
chunks = chunker.chunk(segments, [])
texts = ["".join(seg.text for seg in c.segments) for c in chunks]
self.assertEqual("".join(texts), "aaaabbbbcccc")
self.assertTrue(all(texts))
def test_chunk_images_distributed_across_batches(self):
segments = [DummySeg(0, 1, "aa")]
images = ["i" * 6, "j" * 6, "k" * 6]
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
chunks = chunker.chunk(segments, images)
all_images = [img for c in chunks for img in c.image_urls]
self.assertEqual(all_images, images)
def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self):
segments = [
DummySeg(0, 1, "aaaaaa"),
DummySeg(1, 2, "bbbbbb"),
DummySeg(2, 3, "cccccc"),
]
images = ["11111", "22222", "33333"]
chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator)
chunks = chunker.chunk(segments, images)
self.assertGreaterEqual(len(chunks), 3)
image_counts = [len(c.image_urls) for c in chunks]
self.assertGreater(image_counts[1], 0)
self.assertGreater(image_counts[2], 0)
all_images = [img for c in chunks for img in c.image_urls]
self.assertEqual(all_images, images)
def test_split_oversized_segment(self):
segments = [DummySeg(0, 1, "x" * 25)]
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
chunks = chunker.chunk(segments, [])
combined = "".join(seg.text for c in chunks for seg in c.segments)
self.assertEqual(combined, "x" * 25)
def test_group_texts_by_budget(self):
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
def build_text_messages(texts, *_args, **_kwargs):
content = [{"type": "text", "text": "".join(texts)}]
return [{"role": "user", "content": content}]
groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages)
self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]])
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,35 @@
import importlib.util
import pathlib
import unittest
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "utils" / "screenshot_marker.py"
spec = importlib.util.spec_from_file_location("screenshot_marker", MODULE_PATH)
if spec is None or spec.loader is None:
raise ImportError("screenshot_marker module spec not found")
screenshot_marker = importlib.util.module_from_spec(spec)
spec.loader.exec_module(screenshot_marker)
extract_screenshot_timestamps = screenshot_marker.extract_screenshot_timestamps
class TestScreenshotMarker(unittest.TestCase):
def test_extract_accepts_star_bracket_format(self):
markdown = "A\n*Screenshot-[01:02]\nB"
matches = extract_screenshot_timestamps(markdown)
self.assertEqual(matches, [("*Screenshot-[01:02]", 62)])
def test_extract_accepts_legacy_formats(self):
markdown = "*Screenshot-03:04 and Screenshot-[05:06]"
matches = extract_screenshot_timestamps(markdown)
self.assertEqual(
matches,
[
("*Screenshot-03:04", 184),
("Screenshot-[05:06]", 306),
],
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,42 @@
import importlib.util
import pathlib
import threading
import time
import unittest
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "services" / "task_serial_executor.py"
spec = importlib.util.spec_from_file_location("task_serial_executor", MODULE_PATH)
if spec is None or spec.loader is None:
raise ImportError("task_serial_executor module spec not found")
task_serial_executor = importlib.util.module_from_spec(spec)
spec.loader.exec_module(task_serial_executor)
SerialTaskExecutor = task_serial_executor.SerialTaskExecutor
class TestTaskSerialExecutor(unittest.TestCase):
def test_executor_runs_tasks_one_by_one(self):
executor = SerialTaskExecutor()
state_lock = threading.Lock()
state = {"active": 0, "peak_active": 0}
def critical_work():
with state_lock:
state["active"] += 1
state["peak_active"] = max(state["peak_active"], state["active"])
time.sleep(0.05)
with state_lock:
state["active"] -= 1
threads = [threading.Thread(target=lambda: executor.run(critical_work)) for _ in range(2)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(state["peak_active"], 1)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,147 @@
import importlib.util
import json
import os
import pathlib
import sys
import tempfile
import types
import unittest
from pathlib import Path
def _install_stubs():
app_mod = types.ModuleType("app")
gpt_pkg = types.ModuleType("app.gpt")
models_pkg = types.ModuleType("app.models")
base_mod = types.ModuleType("app.gpt.base")
class _GPT:
pass
base_mod.GPT = _GPT
prompt_builder_mod = types.ModuleType("app.gpt.prompt_builder")
def _generate_base_prompt(**_kwargs):
return "prompt"
prompt_builder_mod.generate_base_prompt = _generate_base_prompt
prompt_mod = types.ModuleType("app.gpt.prompt")
prompt_mod.BASE_PROMPT = ""
prompt_mod.AI_SUM = ""
prompt_mod.SCREENSHOT = ""
prompt_mod.LINK = ""
prompt_mod.MERGE_PROMPT = "merge"
utils_mod = types.ModuleType("app.gpt.utils")
def _fix_markdown(text):
return text
utils_mod.fix_markdown = _fix_markdown
request_chunker_mod = types.ModuleType("app.gpt.request_chunker")
class _RequestChunker:
def __init__(self, *_args, **_kwargs):
pass
def group_texts_by_budget(self, texts, _builder, **_kwargs):
return [texts]
request_chunker_mod.RequestChunker = _RequestChunker
gpt_model_mod = types.ModuleType("app.models.gpt_model")
class _GPTSource:
pass
gpt_model_mod.GPTSource = _GPTSource
transcriber_model_mod = types.ModuleType("app.models.transcriber_model")
class _TranscriptSegment:
def __init__(self, **kwargs):
self.start = kwargs.get("start", 0)
self.end = kwargs.get("end", 0)
self.text = kwargs.get("text", "")
transcriber_model_mod.TranscriptSegment = _TranscriptSegment
sys.modules.setdefault("app", app_mod)
sys.modules.setdefault("app.gpt", gpt_pkg)
sys.modules.setdefault("app.models", models_pkg)
sys.modules["app.gpt.base"] = base_mod
sys.modules["app.gpt.prompt_builder"] = prompt_builder_mod
sys.modules["app.gpt.prompt"] = prompt_mod
sys.modules["app.gpt.utils"] = utils_mod
sys.modules["app.gpt.request_chunker"] = request_chunker_mod
sys.modules["app.models.gpt_model"] = gpt_model_mod
sys.modules["app.models.transcriber_model"] = transcriber_model_mod
def _load_universal_gpt_class():
_install_stubs()
root = pathlib.Path(__file__).resolve().parents[1]
module_path = root / "app" / "gpt" / "universal_gpt.py"
spec = importlib.util.spec_from_file_location("universal_gpt", module_path)
if spec is None or spec.loader is None:
raise ImportError("universal_gpt module spec not found")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.UniversalGPT
UniversalGPT = _load_universal_gpt_class()
class _FailingCompletions:
def create(self, **_kwargs):
raise Exception("Error code: 524 - bad_response_status_code")
class _DummyChat:
def __init__(self):
self.completions = _FailingCompletions()
class _DummyModels:
@staticmethod
def list():
return []
class _DummyClient:
def __init__(self):
self.chat = _DummyChat()
self.models = _DummyModels()
class TestUniversalGPTCheckpoint(unittest.TestCase):
def test_merge_524_error_persists_checkpoint(self):
original_attempts = os.environ.get("OPENAI_RETRY_ATTEMPTS")
os.environ["OPENAI_RETRY_ATTEMPTS"] = "1"
gpt = UniversalGPT(_DummyClient(), model="mock-model")
try:
with tempfile.TemporaryDirectory() as tmp_dir:
gpt.checkpoint_dir = Path(tmp_dir)
with self.assertRaises(Exception):
gpt._merge_partials(["part-a", "part-b"], "task-1", "sig-1")
checkpoint_path = gpt._checkpoint_path("task-1")
self.assertTrue(checkpoint_path.exists())
payload = json.loads(checkpoint_path.read_text(encoding="utf-8"))
self.assertEqual(payload["phase"], "merge")
self.assertEqual(payload["partials"], ["part-a", "part-b"])
finally:
if original_attempts is None:
os.environ.pop("OPENAI_RETRY_ATTEMPTS", None)
else:
os.environ["OPENAI_RETRY_ATTEMPTS"] = original_attempts
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,142 @@
import importlib.util
import pathlib
import re
import sys
import tempfile
import types
import unittest
from unittest.mock import patch
def _install_stubs():
app_mod = types.ModuleType("app")
utils_pkg = types.ModuleType("app.utils")
logger_mod = types.ModuleType("app.utils.logger")
class _Logger:
@staticmethod
def info(*_args, **_kwargs):
return None
@staticmethod
def warning(*_args, **_kwargs):
return None
@staticmethod
def error(*_args, **_kwargs):
return None
def _get_logger(_name):
return _Logger()
logger_mod.get_logger = _get_logger
path_helper_mod = types.ModuleType("app.utils.path_helper")
ffmpeg_mod = types.ModuleType("ffmpeg")
pil_mod = types.ModuleType("PIL")
pil_image_mod = types.ModuleType("PIL.Image")
pil_draw_mod = types.ModuleType("PIL.ImageDraw")
pil_font_mod = types.ModuleType("PIL.ImageFont")
class _FakeImage:
pass
class _FakeImageDraw:
@staticmethod
def Draw(*_args, **_kwargs):
return None
class _FakeImageFont:
@staticmethod
def truetype(*_args, **_kwargs):
return None
@staticmethod
def load_default():
return None
pil_image_mod.Image = _FakeImage
pil_draw_mod.ImageDraw = _FakeImageDraw
pil_font_mod.ImageFont = _FakeImageFont
def _get_app_dir(name):
return name
path_helper_mod.get_app_dir = _get_app_dir
ffmpeg_mod.probe = lambda *_args, **_kwargs: {"format": {"duration": "0"}}
sys.modules.setdefault("app", app_mod)
sys.modules.setdefault("app.utils", utils_pkg)
sys.modules["PIL"] = pil_mod
sys.modules["PIL.Image"] = pil_image_mod
sys.modules["PIL.ImageDraw"] = pil_draw_mod
sys.modules["PIL.ImageFont"] = pil_font_mod
sys.modules["ffmpeg"] = ffmpeg_mod
sys.modules["app.utils.logger"] = logger_mod
sys.modules["app.utils.path_helper"] = path_helper_mod
def _load_video_reader_module():
_install_stubs()
root = pathlib.Path(__file__).resolve().parents[1]
module_path = root / "app" / "utils" / "video_reader.py"
spec = importlib.util.spec_from_file_location("video_reader", module_path)
if spec is None or spec.loader is None:
raise ImportError("video_reader module spec not found")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
video_reader_module = _load_video_reader_module()
VideoReader = video_reader_module.VideoReader
def _make_fake_ffmpeg_runner(colors_by_second):
def _runner(cmd, check=True):
output_path = next((arg for arg in cmd if isinstance(arg, str) and arg.endswith(".jpg")), None)
if output_path is None:
raise AssertionError("Output path not found in ffmpeg cmd")
match = re.search(r"frame_(\d{2})_(\d{2})\.jpg$", output_path)
if match is None:
raise AssertionError("Unexpected output path")
sec = int(match.group(1)) * 60 + int(match.group(2))
payload = colors_by_second[sec]
with open(output_path, "wb") as f:
f.write(payload)
return 0
return _runner
class TestVideoReaderDeduplicateFrames(unittest.TestCase):
def test_extract_frames_skips_adjacent_duplicates_when_enabled(self):
with tempfile.TemporaryDirectory() as tmp_dir:
frame_dir = pathlib.Path(tmp_dir) / "frames"
grid_dir = pathlib.Path(tmp_dir) / "grids"
reader = VideoReader(
video_path="dummy.mp4",
frame_interval=1,
frame_dir=str(frame_dir),
grid_dir=str(grid_dir),
)
fake_colors = {
0: b"frame-a",
1: b"frame-a",
2: b"frame-b",
3: b"frame-b",
}
with patch.object(video_reader_module.ffmpeg, "probe", return_value={"format": {"duration": "4"}}), \
patch.object(video_reader_module.subprocess, "run", side_effect=_make_fake_ffmpeg_runner(fake_colors)):
paths = reader.extract_frames(max_frames=10)
names = [pathlib.Path(p).name for p in paths]
self.assertEqual(names, ["frame_00_00.jpg", "frame_00_02.jpg"])
if __name__ == "__main__":
unittest.main()