# Mirror of https://github.com/JefferyHcool/BiliNote.git
# Synced 2026-05-07 05:43:01 +08:00 (98 lines, 3.6 KiB, Python)
import importlib.util
import pathlib
import sys
import unittest
from dataclasses import dataclass

# Load ``app/gpt/request_chunker.py`` directly from its file path rather than
# through the regular package import system.

# Directory one level above the directory that contains this test file.
ROOT = pathlib.Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py"

spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH)
if spec is None or spec.loader is None:
    raise ImportError("request_chunker module spec not found")

request_chunker = importlib.util.module_from_spec(spec)

# Register the module before executing it, as the importlib recipe for
# importing a source file directly does.  Code that resolves its defining
# module through ``sys.modules`` while the module body is still running
# (``dataclasses`` does this for string annotations) fails otherwise.
sys.modules[spec.name] = request_chunker
try:
    spec.loader.exec_module(request_chunker)
except BaseException:
    # Do not leave a half-initialised module behind if execution fails.
    sys.modules.pop(spec.name, None)
    raise

RequestChunker = request_chunker.RequestChunker


@dataclass
class DummySeg:
    """Minimal segment stand-in passed to the chunker in these tests.

    Within this file only ``text`` is read (by ``build_messages`` and by the
    test assertions); ``start`` and ``end`` are supplied but never read here.
    """

    start: float  # presumably the segment start time; not read in this file
    end: float  # presumably the segment end time; not read in this file
    text: str  # text payload concatenated by the helpers and assertions


def build_messages(segments, image_urls, **_):
    """Build a one-message payload from segments and image URLs.

    Args:
        segments: Iterable of objects exposing a ``text`` attribute; the
            texts are concatenated, in order, into a single text part.
        image_urls: Iterable of URL strings; each becomes one ``image_url``
            part with ``detail`` set to ``"auto"``, in the given order.
        **_: Any extra keyword arguments are accepted and ignored.

    Returns:
        A list holding one ``{"role": "user", "content": [...]}`` dict whose
        content is the text part followed by the image parts.
    """
    content = [{"type": "text", "text": "".join(s.text for s in segments)}]
    content.extend(
        {"type": "image_url", "image_url": {"url": url, "detail": "auto"}}
        for url in image_urls
    )
    return [{"role": "user", "content": content}]


def size_estimator(messages):
    """Return a simple size figure for the first message's content parts.

    Args:
        messages: A list whose first element is a dict with a ``"content"``
            list of parts shaped like those produced by ``build_messages``.

    Returns:
        The sum, over the parts of ``messages[0]["content"]``, of the length
        of ``part["text"]`` for parts whose type is ``"text"`` and of the
        length of ``part["image_url"]["url"]`` for every other part.
    """
    return sum(
        len(part["text"]) if part["type"] == "text" else len(part["image_url"]["url"])
        for part in messages[0]["content"]
    )


class TestRequestChunker(unittest.TestCase):
    """Checks on ``RequestChunker`` using the character-count estimator above."""

    def test_chunk_segments_preserves_order_and_content(self):
        """Chunk texts, joined in order, equal the input and none is empty."""
        segments = [
            DummySeg(0, 1, "aaaa"),
            DummySeg(1, 2, "bbbb"),
            DummySeg(2, 3, "cccc"),
        ]
        chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, [])
        texts = ["".join(seg.text for seg in c.segments) for c in chunks]
        self.assertEqual("".join(texts), "aaaabbbbcccc")
        # Every chunk must carry some text.
        self.assertTrue(all(texts))

    def test_chunk_images_distributed_across_batches(self):
        """Image URLs collected across all chunks keep the input order."""
        segments = [DummySeg(0, 1, "aa")]
        images = ["i" * 6, "j" * 6, "k" * 6]
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, images)
        all_images = [img for c in chunks for img in c.image_urls]
        self.assertEqual(all_images, images)

    def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self):
        """The second and third chunks each receive at least one image."""
        segments = [
            DummySeg(0, 1, "aaaaaa"),
            DummySeg(1, 2, "bbbbbb"),
            DummySeg(2, 3, "cccccc"),
        ]
        images = ["11111", "22222", "33333"]
        chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, images)

        # Guard the index accesses below.
        self.assertGreaterEqual(len(chunks), 3)
        image_counts = [len(c.image_urls) for c in chunks]
        self.assertGreater(image_counts[1], 0)
        self.assertGreater(image_counts[2], 0)
        all_images = [img for c in chunks for img in c.image_urls]
        self.assertEqual(all_images, images)

    def test_split_oversized_segment(self):
        """Text of a 25-character segment is fully recovered from the chunks."""
        segments = [DummySeg(0, 1, "x" * 25)]
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
        chunks = chunker.chunk(segments, [])
        combined = "".join(seg.text for c in chunks for seg in c.segments)
        self.assertEqual(combined, "x" * 25)

    def test_group_texts_by_budget(self):
        """Three 5-character texts with ``max_bytes=10`` group as two then one."""
        chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)

        def build_text_messages(texts, *_args, **_kwargs):
            # Text-only message builder; extra arguments are ignored.
            content = [{"type": "text", "text": "".join(texts)}]
            return [{"role": "user", "content": content}]

        groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages)
        self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]])


if __name__ == "__main__":
    # Run this file's tests when it is executed directly as a script.
    unittest.main()