Files
BiliNote/backend/app/gpt/request_chunker.py
CyanAutumn d9a7b89e7d 🐞 fix: 增加错误之后对已解析段落的缓存功能,再次重试时不再重头开始
解析长视频时,当附件大小过大时不再调用后进行报错,而是将附件进行分批次发送

在每篇笔记开头默认增加地址来源链接,对模糊处可溯源
2026-02-12 18:28:11 +08:00

162 lines
5.9 KiB
Python

from dataclasses import dataclass
from typing import Callable, List, Optional
@dataclass
class ChunkPayload:
    """One request-sized batch: the segments and image URLs sent together."""

    # Segments carried by this chunk (may be empty for an image-only chunk).
    segments: list
    # Image URLs attached to this chunk (may be empty).
    image_urls: list
class RequestChunker:
    """Split segments and image URLs into chunks that fit a byte budget.

    A payload "fits" when the messages produced for it by ``message_builder``
    measure at most ``max_bytes`` according to :meth:`estimate`.
    """

    def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None):
        """Store the collaborators used to build and measure messages.

        Args:
            message_builder: Called as
                ``message_builder(segments, image_urls, **kwargs)``; its
                return value is what gets measured.
            max_bytes: Inclusive upper bound on the estimated size of a chunk.
            size_estimator: Optional ``messages -> int`` callable. When it is
                falsy, the UTF-8 length of the JSON dump is used instead.
        """
        self.message_builder = message_builder
        self.max_bytes = max_bytes
        self.size_estimator = size_estimator

    def estimate(self, messages) -> int:
        """Return the estimated size of ``messages``.

        Delegates to ``size_estimator`` when one was supplied; otherwise
        returns the number of UTF-8 bytes in ``json.dumps(messages)``.
        """
        if self.size_estimator:
            return self.size_estimator(messages)
        # Local import: the module's top-level imports do not include json.
        import json

        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _messages_size(self, segments, image_urls, **kwargs) -> int:
        """Build messages for the given payload and return their estimated size."""
        messages = self.message_builder(segments, image_urls, **kwargs)
        return self.estimate(messages)

    def _get_text(self, segment) -> str:
        """Return the segment's ``text`` (dict key or attribute), or ``""`` if absent."""
        if isinstance(segment, dict):
            return segment.get("text", "")
        return getattr(segment, "text", "")

    def _make_segment(self, segment, text: str):
        """Return a copy of ``segment`` whose text is replaced by ``text``.

        Dicts are shallow-copied. Objects exposing ``__dict__`` are rebuilt by
        passing their attributes as keyword arguments to their own type.
        Anything else is rebuilt positionally as
        ``type(segment)(segment.start, segment.end, text)``.
        """
        if isinstance(segment, dict):
            new_seg = dict(segment)
            new_seg["text"] = text
            return new_seg
        if hasattr(segment, "__dict__"):
            data = dict(segment.__dict__)
            data["text"] = text
            return type(segment)(**data)
        return type(segment)(segment.start, segment.end, text)

    def _split_segment_to_fit(self, segment, **kwargs):
        """Split ``segment`` into a head that fits on its own and the remaining tail.

        Binary-searches for the longest text prefix whose single-segment
        payload (no images) measures at most ``max_bytes``. The search assumes
        the measured size does not shrink as the prefix grows.

        Returns:
            ``(head, tail)`` copies of ``segment`` carrying the prefix and the
            remainder of the text; all other fields are copied unchanged.

        Raises:
            ValueError: If the segment has no text, or if not even a
                one-character prefix fits.
        """
        text = self._get_text(segment)
        if not text:
            raise ValueError("empty segment cannot be split")
        lo, hi = 1, len(text)
        best = None
        while lo <= hi:
            mid = (lo + hi) // 2
            candidate = self._make_segment(segment, text[:mid])
            if self._messages_size([candidate], [], **kwargs) <= self.max_bytes:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1
        if best is None:
            raise ValueError("single segment too large to fit request")
        head = self._make_segment(segment, text[:best])
        tail = self._make_segment(segment, text[best:])
        return head, tail

    def _pack_segments(self, segments: list, builder_kwargs: dict) -> list:
        """Greedily pack ``segments`` into text-only chunks, in order.

        ``segments`` is modified in place: a segment that does not fit into an
        empty chunk is replaced by its head and tail from
        :meth:`_split_segment_to_fit`. ``builder_kwargs`` is taken as a plain
        dict so its keys cannot collide with this helper's parameter names.
        """
        chunks = []
        seg_idx = 0
        while seg_idx < len(segments):
            batch = []
            while seg_idx < len(segments):
                candidate = batch + [segments[seg_idx]]
                if self._messages_size(candidate, [], **builder_kwargs) <= self.max_bytes:
                    batch = candidate
                    seg_idx += 1
                    continue
                if batch:
                    # Current chunk is full; the segment starts the next chunk.
                    break
                # Even alone the segment is too big: split it and retry the head.
                head, tail = self._split_segment_to_fit(segments[seg_idx], **builder_kwargs)
                segments[seg_idx] = head
                segments.insert(seg_idx + 1, tail)
            if not batch:
                raise ValueError("unable to fit any content into chunk")
            chunks.append(ChunkPayload(segments=batch, image_urls=[]))
        return chunks

    def _place_image(self, chunks: list, image, first_idx: int, builder_kwargs: dict) -> None:
        """Attach ``image`` to the first chunk at or after ``first_idx`` with room.

        If no such chunk has room, a new image-only chunk is appended.

        Raises:
            ValueError: If the image does not fit even in a chunk of its own.
        """
        for chunk in chunks[first_idx:]:
            candidate_images = chunk.image_urls + [image]
            if self._messages_size(chunk.segments, candidate_images, **builder_kwargs) <= self.max_bytes:
                chunk.image_urls = candidate_images
                return
        if self._messages_size([], [image], **builder_kwargs) > self.max_bytes:
            raise ValueError("single image payload exceeds max_bytes")
        chunks.append(ChunkPayload(segments=[], image_urls=[image]))

    def chunk(self, segments: list, image_urls: list, **kwargs) -> "List[ChunkPayload]":
        """Split ``segments`` and ``image_urls`` into chunks that each fit ``max_bytes``.

        Segments are packed greedily in order, and an oversized segment is
        split by text. Images are then spread over the text chunks
        proportionally to their position, spilling into later chunks or into
        new image-only chunks when a chunk has no room. With no segments at
        all, images are packed one after another into image-only chunks.

        Args:
            segments: Segments to send (dicts or objects with a ``text``
                field); ``None`` is treated as empty. The caller's list is
                not modified.
            image_urls: Image URLs to attach; ``None`` is treated as empty.
            **kwargs: Forwarded unchanged to ``message_builder``.

        Returns:
            The list of chunks, or ``[]`` when there is nothing to send.

        Raises:
            ValueError: If a segment or an image cannot be made to fit.
        """
        # Work on copies: splitting rewrites the segment list in place.
        segments = list(segments or [])
        image_urls = list(image_urls or [])
        if not segments and not image_urls:
            return []

        chunks = self._pack_segments(segments, kwargs)
        if not image_urls:
            return chunks
        if not chunks:
            chunks = [ChunkPayload(segments=[], image_urls=[])]

        if not segments:
            # Image-only input: only ever try the most recent chunk.
            for image in image_urls:
                self._place_image(chunks, image, len(chunks) - 1, kwargs)
            return chunks

        # Fixed before image-only chunks are appended, so the spread is
        # computed over the text chunks only.
        chunk_count = len(chunks)
        total_images = len(image_urls)
        for idx, image in enumerate(image_urls):
            preferred_idx = min(chunk_count - 1, (idx * chunk_count) // total_images)
            self._place_image(chunks, image, preferred_idx, kwargs)
        return chunks

    def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]:
        """Group consecutive ``texts`` so each group's messages fit ``max_bytes``.

        ``build_messages`` may accept ``(texts, image_urls, **kwargs)`` or just
        ``(texts, **kwargs)``. The two-argument form is tried first, and a
        ``TypeError`` from it triggers a retry with the one-argument form.
        The form that works is remembered and used for all later calls, so a
        later ``TypeError`` raised inside ``build_messages`` propagates as-is
        instead of being replaced by an arity error from a second attempt.

        Args:
            texts: Text blocks to group, in order.
            build_messages: Builder whose output is measured with
                :meth:`estimate`.
            **kwargs: Forwarded unchanged to ``build_messages``.

        Returns:
            A list of groups, each a list of consecutive texts.

        Raises:
            ValueError: If a single text block does not fit on its own.
        """
        groups: List[List[str]] = []
        # None until the calling convention of build_messages is known.
        accepts_images: Optional[bool] = None
        idx = 0
        while idx < len(texts):
            group: List[str] = []
            while idx < len(texts):
                candidate = group + [texts[idx]]
                if accepts_images is None:
                    try:
                        messages = build_messages(candidate, [], **kwargs)
                    except TypeError:
                        messages = build_messages(candidate, **kwargs)
                        accepts_images = False
                    else:
                        accepts_images = True
                elif accepts_images:
                    messages = build_messages(candidate, [], **kwargs)
                else:
                    messages = build_messages(candidate, **kwargs)
                if self.estimate(messages) <= self.max_bytes:
                    group = candidate
                    idx += 1
                    continue
                if not group:
                    raise ValueError("single text block exceeds max_bytes")
                break
            groups.append(group)
        return groups