mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
162 lines · 5.9 KiB · Python
from dataclasses import dataclass
|
|
from typing import Callable, List, Optional
|
|
|
|
|
|
@dataclass()
class ChunkPayload:
    """One chunk of request content produced by ``RequestChunker.chunk``.

    Both fields are plain mutable lists; ``RequestChunker.chunk`` reassigns
    ``image_urls`` on existing chunks while distributing images.
    """

    # Segments (dicts or objects carrying a ``text`` field) included in this chunk.
    segments: list
    # Image URLs included in this chunk.
    image_urls: list
|
|
|
|
|
|
class RequestChunker:
    """Split segments, image URLs, or plain texts into request-sized chunks.

    Sizes are always measured on the *built* request: each candidate batch is
    passed through a message builder and then through :meth:`estimate`, so the
    ``max_bytes`` budget applies to whatever the builder actually produces.
    """

    def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None):
        """
        Args:
            message_builder: Called as ``message_builder(segments, image_urls, **kwargs)``
                to build the messages for a candidate chunk.
            max_bytes: Upper bound on :meth:`estimate` for any single chunk.
            size_estimator: Optional ``callable(messages) -> int``. When omitted,
                the UTF-8 byte length of the JSON-serialised messages is used.
        """
        self.message_builder = message_builder
        self.max_bytes = max_bytes
        self.size_estimator = size_estimator

    def estimate(self, messages) -> int:
        """Return the size of ``messages`` as counted against ``max_bytes``."""
        if self.size_estimator is not None:
            return self.size_estimator(messages)
        import json

        # ensure_ascii=False keeps non-ASCII characters unescaped, so the count
        # reflects their UTF-8 bytes rather than 6-byte \uXXXX escapes.
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _messages_size(self, segments, image_urls, **kwargs) -> int:
        """Build the messages for a candidate chunk and return their estimated size."""
        messages = self.message_builder(segments, image_urls, **kwargs)
        return self.estimate(messages)

    def _get_text(self, segment) -> str:
        """Return the ``text`` of a dict or object segment, or ``""`` when absent."""
        if isinstance(segment, dict):
            return segment.get("text", "")
        return getattr(segment, "text", "")

    def _make_segment(self, segment, text: str):
        """Return a copy of ``segment`` with its ``text`` replaced; ``segment`` is not mutated."""
        if isinstance(segment, dict):
            new_seg = dict(segment)
            new_seg["text"] = text
            return new_seg
        import dataclasses

        if dataclasses.is_dataclass(segment) and not isinstance(segment, type):
            # replace() passes only init fields to the constructor, so it also
            # works for slotted dataclasses and ones with init=False fields.
            return dataclasses.replace(segment, text=text)
        if isinstance(segment, tuple) and hasattr(segment, "_replace"):
            # namedtuple: replace by field name instead of assuming field order.
            return segment._replace(text=text)
        if hasattr(segment, "__dict__"):
            data = dict(segment.__dict__)
            data["text"] = text
            return type(segment)(**data)
        # Last resort: a positional (start, end, text) constructor.
        return type(segment)(segment.start, segment.end, text)

    def _split_segment_to_fit(self, segment, **kwargs):
        """Split ``segment`` into ``(head, tail)``, where ``head`` holds the longest text prefix that fits.

        Both halves keep every other field of ``segment`` unchanged. The binary
        search assumes the built size never shrinks as the prefix grows.

        Raises:
            ValueError: if the segment has no text, or even a one-character
                prefix exceeds ``max_bytes``.
        """
        text = self._get_text(segment)
        if not text:
            raise ValueError("empty segment cannot be split")
        lo, hi = 1, len(text)
        best = None
        while lo <= hi:
            mid = (lo + hi) // 2
            candidate = self._make_segment(segment, text[:mid])
            if self._messages_size([candidate], [], **kwargs) <= self.max_bytes:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1
        if best is None:
            raise ValueError("single segment too large to fit request")
        head = self._make_segment(segment, text[:best])
        tail = self._make_segment(segment, text[best:])
        return head, tail

    def _pack_segments(self, segments: list, builder_kwargs: dict) -> "List[ChunkPayload]":
        """Greedily pack ``segments`` in order into text-only chunks.

        ``segments`` is modified in place when an oversized segment is split.
        ``builder_kwargs`` is a dict (not ``**kwargs``) so caller keyword names
        can never collide with this helper's own parameters.
        """
        chunks = []
        seg_idx = 0
        while seg_idx < len(segments):
            batch = []
            while seg_idx < len(segments):
                candidate = batch + [segments[seg_idx]]
                if self._messages_size(candidate, [], **builder_kwargs) <= self.max_bytes:
                    batch = candidate
                    seg_idx += 1
                    continue
                if batch:
                    break  # current chunk is full; start the next one
                # The segment does not fit even on its own: keep the largest
                # prefix that fits here and queue the remainder right after it.
                head, tail = self._split_segment_to_fit(segments[seg_idx], **builder_kwargs)
                segments[seg_idx] = head
                if self._get_text(tail):
                    # Skip an empty remainder instead of emitting a blank segment.
                    segments.insert(seg_idx + 1, tail)
            if not batch:
                raise ValueError("unable to fit any content into chunk")
            chunks.append(ChunkPayload(segments=batch, image_urls=[]))
        return chunks

    def _image_only_chunk(self, image, builder_kwargs: dict) -> "ChunkPayload":
        """Return a new chunk holding just ``image``.

        Raises:
            ValueError: if ``image`` alone exceeds ``max_bytes``.
        """
        if self._messages_size([], [image], **builder_kwargs) > self.max_bytes:
            raise ValueError("single image payload exceeds max_bytes")
        return ChunkPayload(segments=[], image_urls=[image])

    def chunk(self, segments: list, image_urls: list, **kwargs) -> "List[ChunkPayload]":
        """Pack ``segments`` and ``image_urls`` into chunks whose built size fits ``max_bytes``.

        Segments are packed greedily in order. A segment too large on its own is
        split into a fitting text prefix and a remainder. Images are then spread
        over the segment chunks roughly in proportion to their position. An image
        that fits in none of the chunks from its preferred position onward gets a
        new chunk of its own, appended at the end.

        Args:
            segments: Segment dicts/objects with a ``text`` field; ``None`` means empty.
            image_urls: Images to attach; ``None`` means empty.
            **kwargs: Forwarded unchanged to ``message_builder``.

        Returns:
            The chunks, or ``[]`` when both inputs are empty. The caller's lists
            and segment objects are not modified.

        Raises:
            ValueError: if a segment cannot be split to fit, or a single image
                exceeds ``max_bytes`` on its own.
        """
        segments = list(segments or [])
        image_urls = list(image_urls or [])
        if not segments and not image_urls:
            return []

        chunks = self._pack_segments(segments, kwargs)
        if not image_urls:
            return chunks

        if not segments:
            # Image-only request: fill chunks in order, starting a new chunk
            # whenever the next image does not fit into the last one.
            chunks = [ChunkPayload(segments=[], image_urls=[])]
            for image in image_urls:
                last = chunks[-1]
                candidate_images = last.image_urls + [image]
                if self._messages_size(last.segments, candidate_images, **kwargs) <= self.max_bytes:
                    last.image_urls = candidate_images
                else:
                    chunks.append(self._image_only_chunk(image, kwargs))
            return chunks

        chunk_count = len(chunks)
        total_images = len(image_urls)
        for idx, image in enumerate(image_urls):
            # Spread images proportionally over the segment chunks.
            preferred_idx = min(chunk_count - 1, (idx * chunk_count) // total_images)
            for target in chunks[preferred_idx:]:
                candidate_images = target.image_urls + [image]
                if self._messages_size(target.segments, candidate_images, **kwargs) <= self.max_bytes:
                    target.image_urls = candidate_images
                    break
            else:
                chunks.append(self._image_only_chunk(image, kwargs))
        return chunks

    @staticmethod
    def _accepts_image_arg(build_messages: Callable, builder_kwargs: dict) -> Optional[bool]:
        """Report whether ``build_messages(texts, [], **builder_kwargs)`` binds to its signature.

        Returns ``None`` when the signature cannot be introspected.
        """
        import inspect

        try:
            signature = inspect.signature(build_messages)
        except (TypeError, ValueError):
            return None
        try:
            signature.bind([], [], **builder_kwargs)
        except TypeError:
            return False
        return True

    def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]:
        """Greedily group consecutive ``texts`` so each group's built messages fit ``max_bytes``.

        ``build_messages`` is called as ``build_messages(group, [], **kwargs)`` when
        its signature accepts that, otherwise as ``build_messages(group, **kwargs)``.
        The convention is decided once, so a ``TypeError`` raised *inside* the
        builder propagates instead of triggering a retry with different arguments.

        Returns:
            Groups in input order; each group is as long as the budget allows.

        Raises:
            ValueError: if a single text exceeds ``max_bytes`` on its own.
        """
        pass_images = self._accepts_image_arg(build_messages, kwargs)

        def build(candidate):
            if pass_images is None:
                # Signature not introspectable: fall back to trial and error.
                try:
                    return build_messages(candidate, [], **kwargs)
                except TypeError:
                    return build_messages(candidate, **kwargs)
            if pass_images:
                return build_messages(candidate, [], **kwargs)
            return build_messages(candidate, **kwargs)

        groups: List[List[str]] = []
        idx = 0
        while idx < len(texts):
            group: List[str] = []
            while idx < len(texts):
                candidate = group + [texts[idx]]
                if self.estimate(build(candidate)) <= self.max_bytes:
                    group = candidate
                    idx += 1
                    continue
                if not group:
                    raise ValueError("single text block exceeds max_bytes")
                break
            groups.append(group)
        return groups
|