mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-12 02:20:18 +08:00
🐞 fix: 增加错误之后对已解析段落的缓存功能,再次重试时不再重头开始
解析长视频时,当附件过大时不再直接调用后报错,而是将附件分批次发送;在每篇笔记开头默认增加来源地址链接,便于对模糊处溯源
This commit is contained in:
@@ -18,12 +18,12 @@ BASE_PROMPT = '''
|
||||
- **不要**将输出包裹在代码块中(例如:```` ```markdown ````,```` ``` ````)。
|
||||
请注意,在生成 Markdown 时,避免将编号标题(如“1. **内容**”)写成有序列表的格式,以免解析错误。
|
||||
|
||||
- 如果要加粗并保留编号,应使用 `1\. **内容**`(加反斜杠),防止被误解析为有序列表。
|
||||
- 如果要加粗并保留编号,应使用 `1\\. **内容**`(加反斜杠),防止被误解析为有序列表。
|
||||
- 或者使用 `## 1. 内容` 的形式作为标题。
|
||||
|
||||
请确保以下格式 **不会出现误渲染**:
|
||||
`1. **xxx**`
|
||||
`1\. **xxx**` 或 `## 1. xxx`
|
||||
`1\\. **xxx**` 或 `## 1. xxx`
|
||||
|
||||
视频分段(格式:开始时间 - 内容):
|
||||
|
||||
@@ -66,4 +66,13 @@ SCREENSHOT='''
|
||||
8. **Screenshot placeholders**: If a section involves **visual demonstrations, code walkthroughs, UI interactions**, or any content where visuals aid understanding, insert a screenshot cue at the end of that section:
|
||||
- Format: `*Screenshot-[mm:ss]`
|
||||
- Only use it when truly helpful.
|
||||
'''
|
||||
'''
|
||||
|
||||
MERGE_PROMPT = '''
|
||||
你将收到多个来自同一视频的 Markdown 笔记片段,请合并成一份完整笔记:
|
||||
- 只做合并与去重,不要发明新内容
|
||||
- 保持原有标题层级与 Markdown 结构
|
||||
- 保留所有 *Content-[mm:ss] 与 *Screenshot-[mm:ss] 标记
|
||||
- 保持中文输出,专有名词保留英文
|
||||
- 不要使用代码块包裹输出
|
||||
'''
|
||||
|
||||
161
backend/app/gpt/request_chunker.py
Normal file
161
backend/app/gpt/request_chunker.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
|
||||
@dataclass
class ChunkPayload:
    """One batch of request content: transcript segments plus image URLs."""
    # Transcript segments included in this request batch.
    segments: list
    # Image URLs attached to this request batch.
    image_urls: list


class RequestChunker:
    """Splits request content into batches that fit a byte budget.

    ``message_builder(segments, image_urls, **kwargs)`` must return the
    provider ``messages`` payload; its serialized size (via ``estimate``)
    is compared against ``max_bytes`` to decide batch boundaries.
    A single segment that is too large on its own is bisected on its text
    until a prefix fits.
    """

    def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None):
        # Builds a messages payload from (segments, image_urls, **kwargs).
        self.message_builder = message_builder
        # Hard upper bound (in bytes) for one serialized request.
        self.max_bytes = max_bytes
        # Optional custom size function; defaults to JSON byte length.
        self.size_estimator = size_estimator

    def estimate(self, messages) -> int:
        """Return the (estimated) serialized byte size of *messages*."""
        if self.size_estimator:
            return self.size_estimator(messages)
        import json  # local import: keeps module import surface minimal
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _messages_size(self, segments, image_urls, **kwargs) -> int:
        """Size of the request built from the given segments and images."""
        return self.estimate(self.message_builder(segments, image_urls, **kwargs))

    def _get_text(self, segment) -> str:
        """Read the text of a dict- or attribute-style segment."""
        if isinstance(segment, dict):
            return segment.get("text", "")
        return getattr(segment, "text", "")

    def _make_segment(self, segment, text: str):
        """Clone *segment* with its text replaced by *text*."""
        if isinstance(segment, dict):
            clone = dict(segment)
            clone["text"] = text
            return clone
        if hasattr(segment, "__dict__"):
            fields = dict(segment.__dict__)
            fields["text"] = text
            return type(segment)(**fields)
        # Fallback for slotted / positional segment types (start, end, text).
        return type(segment)(segment.start, segment.end, text)

    def _split_segment_to_fit(self, segment, **kwargs):
        """Binary-search the longest text prefix of *segment* that fits alone.

        Returns a ``(head, tail)`` pair where *head* fits ``max_bytes`` by
        itself. Raises ValueError when the segment is empty or when no
        prefix fits at all.
        """
        text = self._get_text(segment)
        if not text:
            raise ValueError("empty segment cannot be split")
        low, high = 1, len(text)
        best_cut = None
        while low <= high:
            mid = (low + high) // 2
            probe = self._make_segment(segment, text[:mid])
            if self._messages_size([probe], [], **kwargs) <= self.max_bytes:
                best_cut = mid
                low = mid + 1
            else:
                high = mid - 1
        if best_cut is None:
            raise ValueError("single segment too large to fit request")
        head = self._make_segment(segment, text[:best_cut])
        tail = self._make_segment(segment, text[best_cut:])
        return head, tail

    def chunk(self, segments: list, image_urls: list, **kwargs) -> List[ChunkPayload]:
        """Partition segments and images into byte-budgeted ChunkPayloads.

        Segments keep their order; an oversized segment is split in place.
        Images are then spread across the segment chunks starting at their
        proportional position in the video, spilling into extra image-only
        chunks when they do not fit.
        Raises ValueError when a single image alone exceeds ``max_bytes``.
        """
        segments = list(segments or [])
        image_urls = list(image_urls or [])
        if not segments and not image_urls:
            return []

        chunks: List[ChunkPayload] = []
        cursor = 0
        while cursor < len(segments):
            batch: list = []
            while cursor < len(segments):
                trial = batch + [segments[cursor]]
                if self._messages_size(trial, [], **kwargs) <= self.max_bytes:
                    batch = trial
                    cursor += 1
                    continue
                if not batch:
                    # Even alone this segment is too big: split it in place
                    # and retry with the head on the next iteration.
                    head, tail = self._split_segment_to_fit(segments[cursor], **kwargs)
                    segments[cursor] = head
                    segments.insert(cursor + 1, tail)
                    continue
                break
            if not batch:
                raise ValueError("unable to fit any content into chunk")
            chunks.append(ChunkPayload(segments=batch, image_urls=[]))

        if not image_urls:
            return chunks
        if not chunks:
            chunks = [ChunkPayload(segments=[], image_urls=[])]

        if not segments:
            # Image-only payload: fill the last chunk, then open new ones.
            for image in image_urls:
                tail_chunk = chunks[-1]
                with_image = tail_chunk.image_urls + [image]
                if self._messages_size(tail_chunk.segments, with_image, **kwargs) <= self.max_bytes:
                    tail_chunk.image_urls = with_image
                    continue
                if self._messages_size([], [image], **kwargs) > self.max_bytes:
                    raise ValueError("single image payload exceeds max_bytes")
                chunks.append(ChunkPayload(segments=[], image_urls=[image]))
            return chunks

        chunk_count = len(chunks)
        total_images = len(image_urls)
        for position, image in enumerate(image_urls):
            # Proportional start index keeps images near their place in the
            # timeline instead of front-loading them into early chunks.
            start_idx = min(chunk_count - 1, (position * chunk_count) // total_images)
            placed = False
            for chunk in chunks[start_idx:]:
                with_image = chunk.image_urls + [image]
                if self._messages_size(chunk.segments, with_image, **kwargs) <= self.max_bytes:
                    chunk.image_urls = with_image
                    placed = True
                    break
            if placed:
                continue
            if self._messages_size([], [image], **kwargs) > self.max_bytes:
                raise ValueError("single image payload exceeds max_bytes")
            chunks.append(ChunkPayload(segments=[], image_urls=[image]))

        return chunks

    def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]:
        """Greedily group *texts* so each group's built messages fit the budget.

        ``build_messages`` may accept ``(texts, image_urls, **kwargs)`` or
        just ``(texts, **kwargs)``; both call signatures are supported.
        Raises ValueError when a single text alone exceeds ``max_bytes``.
        """
        groups: List[List[str]] = []
        cursor = 0
        while cursor < len(texts):
            current: List[str] = []
            while cursor < len(texts):
                trial = current + [texts[cursor]]
                try:
                    messages = build_messages(trial, [], **kwargs)
                except TypeError:
                    messages = build_messages(trial, **kwargs)
                if self.estimate(messages) <= self.max_bytes:
                    current = trial
                    cursor += 1
                    continue
                if not current:
                    raise ValueError("single text block exceeds max_bytes")
                break
            groups.append(current)
        return groups
|
||||
@@ -1,8 +1,16 @@
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.prompt_builder import generate_base_prompt
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT
|
||||
from app.gpt.utils import fix_markdown
|
||||
from app.gpt.request_chunker import RequestChunker
|
||||
from app.models.transcriber_model import TranscriptSegment
|
||||
from datetime import timedelta
|
||||
from typing import List
|
||||
@@ -15,6 +23,9 @@ class UniversalGPT(GPT):
|
||||
self.temperature = temperature
|
||||
self.screenshot = False
|
||||
self.link = False
|
||||
self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024)))
|
||||
self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
|
||||
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _format_time(self, seconds: float) -> str:
|
||||
return str(timedelta(seconds=int(seconds)))[2:]
|
||||
@@ -40,7 +51,7 @@ class UniversalGPT(GPT):
|
||||
)
|
||||
|
||||
# ⛳ 组装 content 数组,支持 text + image_url 混合
|
||||
content = [{"type": "text", "text": content_text}]
|
||||
content: List[dict] = [{"type": "text", "text": content_text}]
|
||||
video_img_urls = kwargs.get('video_img_urls', [])
|
||||
|
||||
for url in video_img_urls:
|
||||
@@ -63,23 +74,234 @@ class UniversalGPT(GPT):
|
||||
def list_models(self):
|
||||
return self.client.models.list()
|
||||
|
||||
def _estimate_messages_bytes(self, messages: list) -> int:
|
||||
import json
|
||||
return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))
|
||||
|
||||
def _build_merge_messages(self, partials: list) -> list:
|
||||
merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials)
|
||||
return [{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": merge_text}]
|
||||
}]
|
||||
|
||||
def _checkpoint_path(self, checkpoint_key: str) -> Path:
|
||||
safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key)
|
||||
return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json"
|
||||
|
||||
def _build_source_signature(self, source: GPTSource) -> str:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"max_request_bytes": self.max_request_bytes,
|
||||
"title": source.title,
|
||||
"tags": source.tags,
|
||||
"format": source._format,
|
||||
"style": source.style,
|
||||
"extras": source.extras,
|
||||
"video_img_urls": source.video_img_urls or [],
|
||||
"segments": [
|
||||
{
|
||||
"start": getattr(seg, "start", None),
|
||||
"end": getattr(seg, "end", None),
|
||||
"text": getattr(seg, "text", "")
|
||||
}
|
||||
for seg in source.segment
|
||||
],
|
||||
}
|
||||
raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None:
|
||||
path = self._checkpoint_path(checkpoint_key)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if data.get("source_signature") != source_signature:
|
||||
path.unlink(missing_ok=True)
|
||||
return None
|
||||
return data
|
||||
except Exception:
|
||||
path.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None:
|
||||
path = self._checkpoint_path(checkpoint_key)
|
||||
data = {
|
||||
"version": 1,
|
||||
"source_signature": source_signature,
|
||||
"phase": phase,
|
||||
"partials": partials,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
tmp_path = path.with_suffix(".tmp")
|
||||
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _clear_checkpoint(self, checkpoint_key: str) -> None:
|
||||
self._checkpoint_path(checkpoint_key).unlink(missing_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _is_insufficient_quota_error(exc: Exception) -> bool:
|
||||
raw = str(exc)
|
||||
return (
|
||||
"insufficient_user_quota" in raw
|
||||
or "预扣费额度失败" in raw
|
||||
or "insufficient quota" in raw.lower()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_error(exc: Exception) -> bool:
|
||||
raw = str(exc).lower()
|
||||
retryable_tokens = (
|
||||
"error code: 524",
|
||||
"bad_response_status_code",
|
||||
"timed out",
|
||||
"timeout",
|
||||
"rate limit",
|
||||
"error code: 429",
|
||||
"error code: 500",
|
||||
"error code: 502",
|
||||
"error code: 503",
|
||||
"error code: 504",
|
||||
"apiconnectionerror",
|
||||
"connection error",
|
||||
"service unavailable",
|
||||
)
|
||||
if any(token in raw for token in retryable_tokens):
|
||||
return True
|
||||
|
||||
status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
|
||||
return status in {408, 409, 429, 500, 502, 503, 504, 524}
|
||||
|
||||
def _chat_completion_create(self, messages: list):
|
||||
max_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3")))
|
||||
base_backoff = float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5"))
|
||||
|
||||
last_exc = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=self.temperature
|
||||
)
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt == max_attempts - 1 or not self._is_retryable_error(exc):
|
||||
raise
|
||||
sleep_seconds = base_backoff * (2 ** attempt)
|
||||
time.sleep(sleep_seconds)
|
||||
|
||||
if last_exc is not None:
|
||||
raise last_exc
|
||||
raise RuntimeError("chat completion failed without exception")
|
||||
|
||||
def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str:
|
||||
def build_messages(texts, *_args, **_kwargs):
|
||||
return self._build_merge_messages(texts)
|
||||
|
||||
merge_chunker = RequestChunker(
|
||||
lambda *_args, **_kwargs: [],
|
||||
self.max_request_bytes,
|
||||
self._estimate_messages_bytes
|
||||
)
|
||||
|
||||
current_partials = list(partials)
|
||||
while len(current_partials) > 1:
|
||||
groups = merge_chunker.group_texts_by_budget(current_partials, build_messages)
|
||||
new_partials = []
|
||||
for group_idx, group in enumerate(groups):
|
||||
messages = build_messages(group)
|
||||
try:
|
||||
response = self._chat_completion_create(messages)
|
||||
except Exception as exc:
|
||||
if checkpoint_key and source_signature:
|
||||
self._save_checkpoint(checkpoint_key, source_signature, current_partials, "merge")
|
||||
raise
|
||||
|
||||
new_partials.append(response.choices[0].message.content.strip())
|
||||
|
||||
if checkpoint_key and source_signature:
|
||||
remaining_partials = []
|
||||
for remaining_group in groups[group_idx + 1:]:
|
||||
remaining_partials.extend(remaining_group)
|
||||
resumable_partials = new_partials + remaining_partials
|
||||
self._save_checkpoint(checkpoint_key, source_signature, resumable_partials, "merge")
|
||||
|
||||
current_partials = new_partials
|
||||
|
||||
return current_partials[0]
|
||||
|
||||
def summarize(self, source: GPTSource) -> str:
|
||||
self.screenshot = source.screenshot
|
||||
self.link = source.link
|
||||
source.segment = self.ensure_segments_type(source.segment)
|
||||
checkpoint_key = source.checkpoint_key
|
||||
source_signature = self._build_source_signature(source) if checkpoint_key else None
|
||||
|
||||
messages = self.create_messages(
|
||||
source.segment,
|
||||
title=source.title,
|
||||
tags=source.tags,
|
||||
video_img_urls=source.video_img_urls,
|
||||
_format=source._format,
|
||||
style=source.style,
|
||||
extras=source.extras
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
def message_builder(segments, image_urls, **kwargs):
|
||||
return self.create_messages(segments, video_img_urls=image_urls, **kwargs)
|
||||
|
||||
chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes)
|
||||
|
||||
try:
|
||||
chunks = chunker.chunk(
|
||||
source.segment,
|
||||
source.video_img_urls or [],
|
||||
title=source.title,
|
||||
tags=source.tags,
|
||||
_format=source._format,
|
||||
style=source.style,
|
||||
extras=source.extras
|
||||
)
|
||||
except ValueError:
|
||||
chunks = chunker.chunk(
|
||||
source.segment,
|
||||
[],
|
||||
title=source.title,
|
||||
tags=source.tags,
|
||||
_format=source._format,
|
||||
style=source.style,
|
||||
extras=source.extras
|
||||
)
|
||||
|
||||
partials = []
|
||||
if checkpoint_key and source_signature:
|
||||
checkpoint = self._load_checkpoint(checkpoint_key, source_signature)
|
||||
if checkpoint and isinstance(checkpoint.get("partials"), list):
|
||||
partials = checkpoint["partials"]
|
||||
|
||||
if len(partials) > len(chunks):
|
||||
partials = []
|
||||
|
||||
for chunk in chunks[len(partials):]:
|
||||
messages = self.create_messages(
|
||||
chunk.segments,
|
||||
title=source.title,
|
||||
tags=source.tags,
|
||||
video_img_urls=chunk.image_urls,
|
||||
_format=source._format,
|
||||
style=source.style,
|
||||
extras=source.extras
|
||||
)
|
||||
try:
|
||||
response = self._chat_completion_create(messages)
|
||||
except Exception as exc:
|
||||
if checkpoint_key and source_signature:
|
||||
self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
|
||||
raise
|
||||
|
||||
partials.append(response.choices[0].message.content.strip())
|
||||
if checkpoint_key and source_signature:
|
||||
self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize")
|
||||
|
||||
if len(partials) == 1:
|
||||
if checkpoint_key:
|
||||
self._clear_checkpoint(checkpoint_key)
|
||||
return partials[0]
|
||||
merged = self._merge_partials(partials, checkpoint_key, source_signature)
|
||||
if checkpoint_key:
|
||||
self._clear_checkpoint(checkpoint_key)
|
||||
return merged
|
||||
|
||||
@@ -15,4 +15,5 @@ class GPTSource:
|
||||
extras: Optional[str] = None
|
||||
_format: Optional[list] = None
|
||||
video_img_urls: Optional[list] = None
|
||||
checkpoint_key: Optional[str] = None
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.enmus.exception import NoteErrorEnum
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.exceptions.note import NoteError
|
||||
from app.services.note import NoteGenerator, logger
|
||||
from app.services.task_serial_executor import task_serial_executor
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
from app.utils.url_parser import extract_video_id
|
||||
from app.validators.video_url_validator import is_supported_video_url
|
||||
@@ -82,22 +83,26 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
|
||||
if not model_name or not provider_id:
|
||||
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
||||
|
||||
note = NoteGenerator().generate(
|
||||
video_url=video_url,
|
||||
platform=platform,
|
||||
quality=quality,
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras,
|
||||
screenshot=screenshot
|
||||
, video_understanding=video_understanding,
|
||||
video_interval=video_interval,
|
||||
grid_size=grid_size
|
||||
)
|
||||
def _execute_note_task():
|
||||
return NoteGenerator().generate(
|
||||
video_url=video_url,
|
||||
platform=platform,
|
||||
quality=quality,
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras,
|
||||
screenshot=screenshot,
|
||||
video_understanding=video_understanding,
|
||||
video_interval=video_interval,
|
||||
grid_size=grid_size,
|
||||
)
|
||||
|
||||
logger.info(f"任务进入串行队列,等待执行 (task_id={task_id})")
|
||||
note = task_serial_executor.run(_execute_note_task)
|
||||
logger.info(f"Note generated: {task_id}")
|
||||
if not note or not note.markdown:
|
||||
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
|
||||
@@ -144,13 +149,14 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
|
||||
if data.task_id:
|
||||
# 如果传了task_id,说明是重试!
|
||||
task_id = data.task_id
|
||||
# 更新之前的状态
|
||||
NoteGenerator()._update_status(task_id, TaskStatus.PENDING)
|
||||
logger.info(f"重试模式,复用已有 task_id={task_id}")
|
||||
else:
|
||||
# 正常新建任务
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 统一先写入 PENDING,表示已进入队列等待串行执行
|
||||
NoteGenerator()._update_status(task_id, TaskStatus.PENDING)
|
||||
|
||||
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality, data.link,
|
||||
data.screenshot, data.model_name, data.provider_id, data.format, data.style,
|
||||
data.extras, data.video_understanding, data.video_interval, data.grid_size)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, Any
|
||||
@@ -32,7 +31,8 @@ from app.services.constant import SUPPORT_PLATFORM_MAP
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
|
||||
from app.utils.note_helper import replace_content_markers
|
||||
from app.utils.note_helper import replace_content_markers, prepend_source_link
|
||||
from app.utils.screenshot_marker import extract_screenshot_timestamps
|
||||
from app.utils.status_code import StatusCode
|
||||
from app.utils.video_helper import generate_screenshot
|
||||
from app.utils.video_reader import VideoReader
|
||||
@@ -182,6 +182,8 @@ class NoteGenerator:
|
||||
platform=platform,
|
||||
)
|
||||
|
||||
markdown = prepend_source_link(markdown, str(video_url))
|
||||
|
||||
# 5. 保存记录到数据库
|
||||
self._update_status(task_id, TaskStatus.SAVING)
|
||||
self._save_metadata(video_id=audio_meta.video_id, platform=platform, task_id=task_id)
|
||||
@@ -353,6 +355,10 @@ class NoteGenerator:
|
||||
|
||||
# 判断是否需要下载视频
|
||||
need_video = screenshot or video_understanding
|
||||
if screenshot and not grid_size:
|
||||
grid_size = [2, 2]
|
||||
|
||||
frame_interval = video_interval if video_interval and video_interval > 0 else 6
|
||||
if need_video:
|
||||
try:
|
||||
logger.info("开始下载视频")
|
||||
@@ -365,10 +371,10 @@ class NoteGenerator:
|
||||
self.video_img_urls=VideoReader(
|
||||
video_path=str(self.video_path),
|
||||
grid_size=tuple(grid_size),
|
||||
frame_interval=video_interval,
|
||||
unit_width=1280,
|
||||
unit_height=720,
|
||||
save_quality=90,
|
||||
frame_interval=frame_interval,
|
||||
unit_width=960,
|
||||
unit_height=540,
|
||||
save_quality=80,
|
||||
).run()
|
||||
else:
|
||||
logger.info("未指定 grid_size,跳过缩略图生成")
|
||||
@@ -540,6 +546,7 @@ class NoteGenerator:
|
||||
_format=formats,
|
||||
style=style,
|
||||
extras=extras,
|
||||
checkpoint_key=task_id,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -592,7 +599,7 @@ class NoteGenerator:
|
||||
:param video_path: 本地视频文件路径
|
||||
:return: 替换后的 Markdown 字符串
|
||||
"""
|
||||
matches: List[Tuple[str, int]] = self._extract_screenshot_timestamps(markdown)
|
||||
matches: List[Tuple[str, int]] = extract_screenshot_timestamps(markdown)
|
||||
for idx, (marker, ts) in enumerate(matches):
|
||||
try:
|
||||
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
|
||||
@@ -615,14 +622,7 @@ class NoteGenerator:
|
||||
:param markdown: 原始 Markdown 文本
|
||||
:return: 标记与对应时间戳秒数的列表
|
||||
"""
|
||||
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
|
||||
results: List[Tuple[str, int]] = []
|
||||
for match in re.finditer(pattern, markdown):
|
||||
mm = match.group(1) or match.group(3)
|
||||
ss = match.group(2) or match.group(4)
|
||||
total_seconds = int(mm) * 60 + int(ss)
|
||||
results.append((match.group(0), total_seconds))
|
||||
return results
|
||||
return extract_screenshot_timestamps(markdown)
|
||||
|
||||
def _save_metadata(self, video_id: str, platform: str, task_id: str) -> None:
|
||||
"""
|
||||
@@ -636,4 +636,4 @@ class NoteGenerator:
|
||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||
logger.info(f"已保存任务记录到数据库 (video_id={video_id}, platform={platform}, task_id={task_id})")
|
||||
except Exception as e:
|
||||
logger.error(f"保存任务记录失败:{e}")
|
||||
logger.error(f"保存任务记录失败:{e}")
|
||||
|
||||
14
backend/app/services/task_serial_executor.py
Normal file
14
backend/app/services/task_serial_executor.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class SerialTaskExecutor:
    """Runs submitted callables one at a time behind a process-wide lock."""

    def __init__(self):
        # Lock serializing all task execution within this process.
        self._lock = threading.Lock()

    def run(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Execute ``fn(*args, **kwargs)`` while holding the serialization lock."""
        with self._lock:
            return fn(*args, **kwargs)


# Process-wide singleton used to serialize note-generation tasks.
task_serial_executor = SerialTaskExecutor()
|
||||
@@ -5,6 +5,37 @@ import re
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def prepend_source_link(markdown: str | None, source_url: str) -> str | None:
|
||||
"""
|
||||
在笔记开头添加来源链接;若首个非空行已包含来源链接,则更新该行并避免重复。
|
||||
"""
|
||||
if markdown is None:
|
||||
return None
|
||||
|
||||
source = (source_url or "").strip()
|
||||
if not source:
|
||||
return markdown
|
||||
|
||||
header = f"> 来源链接:{source}"
|
||||
lines = markdown.splitlines()
|
||||
first_non_empty_idx = None
|
||||
for idx, line in enumerate(lines):
|
||||
if line.strip():
|
||||
first_non_empty_idx = idx
|
||||
break
|
||||
|
||||
if first_non_empty_idx is not None:
|
||||
first_line = lines[first_non_empty_idx].strip()
|
||||
if first_line.startswith("> 来源链接:") or first_line.startswith("来源链接:"):
|
||||
lines[first_non_empty_idx] = header
|
||||
return "\n".join(lines)
|
||||
|
||||
if markdown.strip():
|
||||
return f"{header}\n\n{markdown}"
|
||||
return header
|
||||
|
||||
|
||||
def replace_content_markers(markdown: str, video_id: str, platform: str = 'bilibili') -> str:
|
||||
"""
|
||||
替换 *Content-04:16*、Content-04:16 或 Content-[04:16] 为超链接,跳转到对应平台视频的时间位置
|
||||
@@ -12,18 +43,20 @@ def replace_content_markers(markdown: str, video_id: str, platform: str = 'bilib
|
||||
# 匹配三种形式:*Content-04:16*、Content-04:16、Content-[04:16]
|
||||
pattern = r"(?:\*?)Content-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2}))"
|
||||
|
||||
safe_video_id = video_id
|
||||
|
||||
def replacer(match):
|
||||
mm = match.group(1) or match.group(3)
|
||||
ss = match.group(2) or match.group(4)
|
||||
total_seconds = int(mm) * 60 + int(ss)
|
||||
|
||||
if platform == 'bilibili':
|
||||
video_id = video_id.replace("_p", "?p=")
|
||||
url = f"https://www.bilibili.com/video/{video_id}&t={total_seconds}"
|
||||
parsed_video_id = safe_video_id.replace("_p", "?p=")
|
||||
url = f"https://www.bilibili.com/video/{parsed_video_id}&t={total_seconds}"
|
||||
elif platform == 'youtube':
|
||||
url = f"https://www.youtube.com/watch?v={video_id}&t={total_seconds}s"
|
||||
url = f"https://www.youtube.com/watch?v={safe_video_id}&t={total_seconds}s"
|
||||
elif platform == 'douyin':
|
||||
url = f"https://www.douyin.com/video/{video_id}"
|
||||
url = f"https://www.douyin.com/video/{safe_video_id}"
|
||||
return f"[原片 @ {mm}:{ss}]({url})"
|
||||
else:
|
||||
return f"({mm}:{ss})"
|
||||
|
||||
13
backend/app/utils/screenshot_marker.py
Normal file
13
backend/app/utils/screenshot_marker.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def extract_screenshot_timestamps(markdown: str) -> List[Tuple[str, int]]:
    """Find screenshot markers and their timestamps in *markdown*.

    Matches both ``Screenshot-mm:ss`` and ``Screenshot-[mm:ss]`` forms (a
    leading ``*`` is optional) and returns ``(matched_marker, total_seconds)``
    pairs in document order.
    """
    pattern = r"(\*?Screenshot-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2})))"
    markers: List[Tuple[str, int]] = []
    for hit in re.finditer(pattern, markdown):
        minutes = hit.group(2) or hit.group(4)
        seconds = hit.group(3) or hit.group(5)
        markers.append((hit.group(1), int(minutes) * 60 + int(seconds)))
    return markers
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
@@ -14,6 +15,7 @@ class VideoReader:
|
||||
video_path: str,
|
||||
grid_size=(3, 3),
|
||||
frame_interval=2,
|
||||
dedupe_enabled=True,
|
||||
unit_width=960,
|
||||
unit_height=540,
|
||||
save_quality=90,
|
||||
@@ -23,6 +25,7 @@ class VideoReader:
|
||||
self.video_path = video_path
|
||||
self.grid_size = grid_size
|
||||
self.frame_interval = frame_interval
|
||||
self.dedupe_enabled = dedupe_enabled
|
||||
self.unit_width = unit_width
|
||||
self.unit_height = unit_height
|
||||
self.save_quality = save_quality
|
||||
@@ -31,6 +34,14 @@ class VideoReader:
|
||||
print(f"视频路径:{video_path}",self.frame_dir,self.grid_dir)
|
||||
self.font_path = font_path
|
||||
|
||||
@staticmethod
|
||||
def _calculate_file_md5(file_path: str) -> str:
|
||||
hasher = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
def format_time(self, seconds: float) -> str:
|
||||
mm = int(seconds // 60)
|
||||
ss = int(seconds % 60)
|
||||
@@ -51,12 +62,21 @@ class VideoReader:
|
||||
timestamps = [i for i in range(0, int(duration), self.frame_interval)][:max_frames]
|
||||
|
||||
image_paths = []
|
||||
last_hash = None
|
||||
for ts in timestamps:
|
||||
time_label = self.format_time(ts)
|
||||
output_path = os.path.join(self.frame_dir, f"frame_{time_label}.jpg")
|
||||
cmd = ["ffmpeg", "-ss", str(ts), "-i", self.video_path, "-frames:v", "1", "-q:v", "2", "-y", output_path,
|
||||
"-hide_banner", "-loglevel", "error"]
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
if self.dedupe_enabled:
|
||||
frame_hash = self._calculate_file_md5(output_path)
|
||||
if frame_hash == last_hash:
|
||||
os.remove(output_path)
|
||||
continue
|
||||
last_hash = frame_hash
|
||||
|
||||
image_paths.append(output_path)
|
||||
return image_paths
|
||||
except Exception as e:
|
||||
|
||||
1
backend/run.bat
Normal file
1
backend/run.bat
Normal file
@@ -0,0 +1 @@
|
||||
python main.py
|
||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
35
backend/tests/test_note_helper.py
Normal file
35
backend/tests/test_note_helper.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
|
||||
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "app" / "utils" / "note_helper.py"
|
||||
spec = importlib.util.spec_from_file_location("note_helper", MODULE_PATH)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError("note_helper module spec not found")
|
||||
note_helper = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(note_helper)
|
||||
|
||||
|
||||
class TestNoteHelper(unittest.TestCase):
    """Unit tests for note_helper.prepend_source_link."""

    def test_prepend_source_link_adds_header_at_top(self):
        # A fresh note gets the source-link quote prepended before the title.
        url = "https://www.bilibili.com/video/BV1xx411c7mD"
        note = "## 标题\n\n内容"

        updated = note_helper.prepend_source_link(note, url)

        self.assertTrue(updated.startswith(f"> 来源链接:{url}\n\n"))
        self.assertIn("## 标题", updated)

    def test_prepend_source_link_does_not_duplicate_when_header_exists(self):
        # Re-running on a note that already carries the header is a no-op.
        url = "https://www.youtube.com/watch?v=abc123"
        note = f"> 来源链接:{url}\n\n## 标题\n\n内容"

        self.assertEqual(note_helper.prepend_source_link(note, url), note)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
97
backend/tests/test_request_chunker.py
Normal file
97
backend/tests/test_request_chunker.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py"
|
||||
spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError("request_chunker module spec not found")
|
||||
request_chunker = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(request_chunker)
|
||||
RequestChunker = request_chunker.RequestChunker
|
||||
|
||||
|
||||
@dataclass
class DummySeg:
    """Minimal stand-in for a transcript segment: a time span plus its text."""

    # Segment start time in seconds.
    start: float
    # Segment end time in seconds.
    end: float
    # Transcript text carried by the segment.
    text: str
||||
|
||||
|
||||
def build_messages(segments, image_urls, **_):
    """Assemble a single user chat message from segments and image URLs.

    All segment texts are concatenated into one text part; each image URL
    becomes an image_url part appended after the text, in the given order.
    Extra keyword arguments are accepted and ignored.
    """
    joined_text = "".join(segment.text for segment in segments)
    parts = [{"type": "text", "text": joined_text}]
    parts.extend(
        {"type": "image_url", "image_url": {"url": url, "detail": "auto"}}
        for url in image_urls
    )
    return [{"role": "user", "content": parts}]
|
||||
|
||||
|
||||
def size_estimator(messages):
    """Estimate payload size of the first message as a character count.

    Text parts contribute len(text); every other part is assumed to be an
    image part and contributes len(image_url.url).
    """

    def _part_length(part):
        # Text parts carry their payload under "text"; others under the URL.
        if part["type"] == "text":
            return len(part["text"])
        return len(part["image_url"]["url"])

    return sum(_part_length(part) for part in messages[0]["content"])
|
||||
|
||||
|
||||
class TestRequestChunker(unittest.TestCase):
    """Behavioural tests for RequestChunker batching of segments and images."""

    def test_chunk_segments_preserves_order_and_content(self):
        """Chunking must keep every character, in order, with no empty chunk."""
        segments = [
            DummySeg(0, 1, "aaaa"),
            DummySeg(1, 2, "bbbb"),
            DummySeg(2, 3, "cccc"),
        ]
        chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, [])
        texts = ["".join(seg.text for seg in c.segments) for c in chunks]
        self.assertEqual("".join(texts), "aaaabbbbcccc")
        self.assertTrue(all(texts))  # no chunk may come back empty

    def test_chunk_images_distributed_across_batches(self):
        """Every image URL must survive chunking, in the original order."""
        segments = [DummySeg(0, 1, "aa")]
        images = ["i" * 6, "j" * 6, "k" * 6]
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, images)
        all_images = [img for c in chunks for img in c.image_urls]
        self.assertEqual(all_images, images)

    def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self):
        """Images must be spread across chunks, not packed into the first one."""
        segments = [
            DummySeg(0, 1, "aaaaaa"),
            DummySeg(1, 2, "bbbbbb"),
            DummySeg(2, 3, "cccccc"),
        ]
        images = ["11111", "22222", "33333"]
        chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, images)

        self.assertGreaterEqual(len(chunks), 3)
        image_counts = [len(c.image_urls) for c in chunks]
        # Later chunks must also carry images — proves no front-loading.
        self.assertGreater(image_counts[1], 0)
        self.assertGreater(image_counts[2], 0)
        all_images = [img for c in chunks for img in c.image_urls]
        self.assertEqual(all_images, images)

    def test_split_oversized_segment(self):
        """A single segment larger than the budget is split without text loss."""
        segments = [DummySeg(0, 1, "x" * 25)]
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, [])
        combined = "".join(seg.text for c in chunks for seg in c.segments)
        self.assertEqual(combined, "x" * 25)

    def test_group_texts_by_budget(self):
        """Consecutive texts are grouped so each group fits the byte budget."""
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)

        def build_text_messages(texts, *_args, **_kwargs):
            # Builder used only for size measurement of a candidate group.
            content = [{"type": "text", "text": "".join(texts)}]
            return [{"role": "user", "content": content}]

        groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages)
        self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]])


if __name__ == "__main__":
    unittest.main()
|
||||
35
backend/tests/test_screenshot_marker.py
Normal file
35
backend/tests/test_screenshot_marker.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
|
||||
# Load screenshot_marker directly from its source file so the test does not
# depend on the app package being importable.
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "utils" / "screenshot_marker.py"
spec = importlib.util.spec_from_file_location("screenshot_marker", MODULE_PATH)
if spec is None or spec.loader is None:
    raise ImportError("screenshot_marker module spec not found")
screenshot_marker = importlib.util.module_from_spec(spec)
spec.loader.exec_module(screenshot_marker)
extract_screenshot_timestamps = screenshot_marker.extract_screenshot_timestamps
|
||||
|
||||
|
||||
class TestScreenshotMarker(unittest.TestCase):
    """extract_screenshot_timestamps yields (matched marker, time in seconds)."""

    def test_extract_accepts_star_bracket_format(self):
        # 01:02 -> 1*60 + 2 = 62 seconds.
        markdown = "A\n*Screenshot-[01:02]\nB"
        matches = extract_screenshot_timestamps(markdown)
        self.assertEqual(matches, [("*Screenshot-[01:02]", 62)])

    def test_extract_accepts_legacy_formats(self):
        # Older notes used *Screenshot-mm:ss and Screenshot-[mm:ss]; both
        # must still parse (03:04 -> 184s, 05:06 -> 306s).
        markdown = "*Screenshot-03:04 and Screenshot-[05:06]"
        matches = extract_screenshot_timestamps(markdown)
        self.assertEqual(
            matches,
            [
                ("*Screenshot-03:04", 184),
                ("Screenshot-[05:06]", 306),
            ],
        )


if __name__ == "__main__":
    unittest.main()
|
||||
42
backend/tests/test_task_serial_executor.py
Normal file
42
backend/tests/test_task_serial_executor.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
|
||||
# Load SerialTaskExecutor directly from its source file so the test does not
# depend on the app package being importable.
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "services" / "task_serial_executor.py"
spec = importlib.util.spec_from_file_location("task_serial_executor", MODULE_PATH)
if spec is None or spec.loader is None:
    raise ImportError("task_serial_executor module spec not found")
task_serial_executor = importlib.util.module_from_spec(spec)
spec.loader.exec_module(task_serial_executor)
SerialTaskExecutor = task_serial_executor.SerialTaskExecutor
|
||||
|
||||
|
||||
class TestTaskSerialExecutor(unittest.TestCase):
    """SerialTaskExecutor must never run two submitted tasks concurrently."""

    def test_executor_runs_tasks_one_by_one(self):
        executor = SerialTaskExecutor()
        state_lock = threading.Lock()
        # active: tasks currently inside critical_work; peak_active: maximum
        # concurrency ever observed.
        state = {"active": 0, "peak_active": 0}

        def critical_work():
            with state_lock:
                state["active"] += 1
                state["peak_active"] = max(state["peak_active"], state["active"])
            # Hold the "critical section" long enough that two unserialised
            # threads would overlap here.
            time.sleep(0.05)
            with state_lock:
                state["active"] -= 1

        threads = [threading.Thread(target=lambda: executor.run(critical_work)) for _ in range(2)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # If the executor serialised the calls, at most one task was active.
        self.assertEqual(state["peak_active"], 1)


if __name__ == "__main__":
    unittest.main()
|
||||
147
backend/tests/test_universal_gpt_checkpoint.py
Normal file
147
backend/tests/test_universal_gpt_checkpoint.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _install_stubs():
    """Register lightweight stand-ins for every app.* module universal_gpt imports.

    universal_gpt.py is loaded from its file path, so each project import it
    performs must already be resolvable through sys.modules.  Every stub
    provides only the names the module under test actually touches.
    """
    app_mod = types.ModuleType("app")
    gpt_pkg = types.ModuleType("app.gpt")
    models_pkg = types.ModuleType("app.models")

    base_mod = types.ModuleType("app.gpt.base")

    class _GPT:
        pass

    base_mod.GPT = _GPT

    prompt_builder_mod = types.ModuleType("app.gpt.prompt_builder")

    def _generate_base_prompt(**_kwargs):
        # Fixed prompt text; the test does not depend on its content.
        return "prompt"

    prompt_builder_mod.generate_base_prompt = _generate_base_prompt

    prompt_mod = types.ModuleType("app.gpt.prompt")
    prompt_mod.BASE_PROMPT = ""
    prompt_mod.AI_SUM = ""
    prompt_mod.SCREENSHOT = ""
    prompt_mod.LINK = ""
    prompt_mod.MERGE_PROMPT = "merge"

    utils_mod = types.ModuleType("app.gpt.utils")

    def _fix_markdown(text):
        # Identity: markdown post-processing is out of scope for this test.
        return text

    utils_mod.fix_markdown = _fix_markdown

    request_chunker_mod = types.ModuleType("app.gpt.request_chunker")

    class _RequestChunker:
        def __init__(self, *_args, **_kwargs):
            pass

        def group_texts_by_budget(self, texts, _builder, **_kwargs):
            # Pretend everything fits into a single batch.
            return [texts]

    request_chunker_mod.RequestChunker = _RequestChunker

    gpt_model_mod = types.ModuleType("app.models.gpt_model")

    class _GPTSource:
        pass

    gpt_model_mod.GPTSource = _GPTSource

    transcriber_model_mod = types.ModuleType("app.models.transcriber_model")

    class _TranscriptSegment:
        # Keyword-only construction with permissive defaults.
        def __init__(self, **kwargs):
            self.start = kwargs.get("start", 0)
            self.end = kwargs.get("end", 0)
            self.text = kwargs.get("text", "")

    transcriber_model_mod.TranscriptSegment = _TranscriptSegment

    # Package entries use setdefault so an already-imported real package wins;
    # leaf modules are assigned unconditionally so the stubs above always apply.
    sys.modules.setdefault("app", app_mod)
    sys.modules.setdefault("app.gpt", gpt_pkg)
    sys.modules.setdefault("app.models", models_pkg)
    sys.modules["app.gpt.base"] = base_mod
    sys.modules["app.gpt.prompt_builder"] = prompt_builder_mod
    sys.modules["app.gpt.prompt"] = prompt_mod
    sys.modules["app.gpt.utils"] = utils_mod
    sys.modules["app.gpt.request_chunker"] = request_chunker_mod
    sys.modules["app.models.gpt_model"] = gpt_model_mod
    sys.modules["app.models.transcriber_model"] = transcriber_model_mod
|
||||
|
||||
|
||||
def _load_universal_gpt_class():
    """Install the stub modules, then load UniversalGPT from its source file."""
    _install_stubs()
    root = pathlib.Path(__file__).resolve().parents[1]
    module_path = root / "app" / "gpt" / "universal_gpt.py"
    spec = importlib.util.spec_from_file_location("universal_gpt", module_path)
    if spec is None or spec.loader is None:
        raise ImportError("universal_gpt module spec not found")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.UniversalGPT


# Resolved once at import time; all tests below share this class object.
UniversalGPT = _load_universal_gpt_class()
|
||||
|
||||
|
||||
class _FailingCompletions:
|
||||
def create(self, **_kwargs):
|
||||
raise Exception("Error code: 524 - bad_response_status_code")
|
||||
|
||||
|
||||
class _DummyChat:
    """Mimics client.chat: exposes a completions attribute that always fails."""

    def __init__(self):
        self.completions = _FailingCompletions()
|
||||
|
||||
|
||||
class _DummyModels:
    """Mimics client.models; list() reports that no models are available."""

    @staticmethod
    def list():
        return []
|
||||
|
||||
|
||||
class _DummyClient:
    """Minimal OpenAI-client stand-in: completions fail, model list is empty."""

    def __init__(self):
        self.chat = _DummyChat()
        self.models = _DummyModels()
|
||||
|
||||
|
||||
class TestUniversalGPTCheckpoint(unittest.TestCase):
    """Checkpoint persistence when the merge step fails mid-way."""

    def test_merge_524_error_persists_checkpoint(self):
        # Limit retries to one so the stubbed 524 failure surfaces immediately.
        original_attempts = os.environ.get("OPENAI_RETRY_ATTEMPTS")
        os.environ["OPENAI_RETRY_ATTEMPTS"] = "1"
        gpt = UniversalGPT(_DummyClient(), model="mock-model")
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                gpt.checkpoint_dir = Path(tmp_dir)

                # The merge must fail: the completions stub raises a 524 error.
                with self.assertRaises(Exception):
                    gpt._merge_partials(["part-a", "part-b"], "task-1", "sig-1")

                # ...but the partial results must have been checkpointed first,
                # so a retry can resume instead of starting from scratch.
                checkpoint_path = gpt._checkpoint_path("task-1")
                self.assertTrue(checkpoint_path.exists())
                payload = json.loads(checkpoint_path.read_text(encoding="utf-8"))
                self.assertEqual(payload["phase"], "merge")
                self.assertEqual(payload["partials"], ["part-a", "part-b"])
        finally:
            # Restore the environment variable exactly as it was.
            if original_attempts is None:
                os.environ.pop("OPENAI_RETRY_ATTEMPTS", None)
            else:
                os.environ["OPENAI_RETRY_ATTEMPTS"] = original_attempts


if __name__ == "__main__":
    unittest.main()
|
||||
142
backend/tests/test_video_reader_dedupe.py
Normal file
142
backend/tests/test_video_reader_dedupe.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _install_stubs():
    """Register stubs for every external import video_reader.py performs.

    Covers app.utils.logger / app.utils.path_helper plus the ffmpeg and PIL
    modules, so the module under test loads without those packages installed.
    """
    app_mod = types.ModuleType("app")
    utils_pkg = types.ModuleType("app.utils")

    logger_mod = types.ModuleType("app.utils.logger")

    class _Logger:
        # Swallow all log output during tests.
        @staticmethod
        def info(*_args, **_kwargs):
            return None

        @staticmethod
        def warning(*_args, **_kwargs):
            return None

        @staticmethod
        def error(*_args, **_kwargs):
            return None

    def _get_logger(_name):
        return _Logger()

    logger_mod.get_logger = _get_logger

    path_helper_mod = types.ModuleType("app.utils.path_helper")
    ffmpeg_mod = types.ModuleType("ffmpeg")

    pil_mod = types.ModuleType("PIL")
    pil_image_mod = types.ModuleType("PIL.Image")
    pil_draw_mod = types.ModuleType("PIL.ImageDraw")
    pil_font_mod = types.ModuleType("PIL.ImageFont")

    class _FakeImage:
        pass

    class _FakeImageDraw:
        @staticmethod
        def Draw(*_args, **_kwargs):
            return None

    class _FakeImageFont:
        @staticmethod
        def truetype(*_args, **_kwargs):
            return None

        @staticmethod
        def load_default():
            return None

    pil_image_mod.Image = _FakeImage
    pil_draw_mod.ImageDraw = _FakeImageDraw
    pil_font_mod.ImageFont = _FakeImageFont

    def _get_app_dir(name):
        # Map app-dir lookups to a relative path inside the test sandbox.
        return name

    path_helper_mod.get_app_dir = _get_app_dir
    # Default probe reports zero duration; tests patch this where needed.
    ffmpeg_mod.probe = lambda *_args, **_kwargs: {"format": {"duration": "0"}}

    # Package entries use setdefault so an already-imported real package wins;
    # leaf modules are assigned unconditionally so the stubs always apply.
    sys.modules.setdefault("app", app_mod)
    sys.modules.setdefault("app.utils", utils_pkg)
    sys.modules["PIL"] = pil_mod
    sys.modules["PIL.Image"] = pil_image_mod
    sys.modules["PIL.ImageDraw"] = pil_draw_mod
    sys.modules["PIL.ImageFont"] = pil_font_mod
    sys.modules["ffmpeg"] = ffmpeg_mod
    sys.modules["app.utils.logger"] = logger_mod
    sys.modules["app.utils.path_helper"] = path_helper_mod
|
||||
|
||||
|
||||
def _load_video_reader_module():
    """Install stubs, then load video_reader from its source file."""
    _install_stubs()
    root = pathlib.Path(__file__).resolve().parents[1]
    module_path = root / "app" / "utils" / "video_reader.py"
    spec = importlib.util.spec_from_file_location("video_reader", module_path)
    if spec is None or spec.loader is None:
        raise ImportError("video_reader module spec not found")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Loaded once at import time; tests patch attributes on this module object.
video_reader_module = _load_video_reader_module()
VideoReader = video_reader_module.VideoReader
|
||||
|
||||
|
||||
def _make_fake_ffmpeg_runner(colors_by_second):
|
||||
def _runner(cmd, check=True):
|
||||
output_path = next((arg for arg in cmd if isinstance(arg, str) and arg.endswith(".jpg")), None)
|
||||
if output_path is None:
|
||||
raise AssertionError("Output path not found in ffmpeg cmd")
|
||||
match = re.search(r"frame_(\d{2})_(\d{2})\.jpg$", output_path)
|
||||
if match is None:
|
||||
raise AssertionError("Unexpected output path")
|
||||
sec = int(match.group(1)) * 60 + int(match.group(2))
|
||||
payload = colors_by_second[sec]
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(payload)
|
||||
return 0
|
||||
|
||||
return _runner
|
||||
|
||||
|
||||
class TestVideoReaderDeduplicateFrames(unittest.TestCase):
    """Frame extraction should drop consecutive frames with identical bytes."""

    def test_extract_frames_skips_adjacent_duplicates_when_enabled(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            frame_dir = pathlib.Path(tmp_dir) / "frames"
            grid_dir = pathlib.Path(tmp_dir) / "grids"
            reader = VideoReader(
                video_path="dummy.mp4",
                frame_interval=1,
                frame_dir=str(frame_dir),
                grid_dir=str(grid_dir),
            )

            # Seconds 0/1 share one payload and 2/3 share another, so only the
            # first frame of each identical run should survive deduplication.
            fake_colors = {
                0: b"frame-a",
                1: b"frame-a",
                2: b"frame-b",
                3: b"frame-b",
            }

            # Patch out the real ffmpeg probe (4-second video) and subprocess
            # invocation (fake runner writes the payload bytes above).
            with patch.object(video_reader_module.ffmpeg, "probe", return_value={"format": {"duration": "4"}}), \
                patch.object(video_reader_module.subprocess, "run", side_effect=_make_fake_ffmpeg_runner(fake_colors)):
                paths = reader.extract_frames(max_frames=10)

            names = [pathlib.Path(p).name for p in paths]
            self.assertEqual(names, ["frame_00_00.jpg", "frame_00_02.jpg"])


if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user