mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-03 23:01:38 +08:00
🐞 fix: 增加错误之后对已解析段落的缓存功能,再次重试时不再重头开始
解析长视频时,当附件大小过大时不再调用后进行报错,而是将附件进行分批次发送 在每篇笔记开头默认增加地址来源链接,对模糊处可溯源
This commit is contained in:
97
backend/tests/test_request_chunker.py
Normal file
97
backend/tests/test_request_chunker.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py"
|
||||
spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError("request_chunker module spec not found")
|
||||
request_chunker = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(request_chunker)
|
||||
RequestChunker = request_chunker.RequestChunker
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySeg:
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
|
||||
|
||||
def build_messages(segments, image_urls, **_):
|
||||
content = [{"type": "text", "text": "".join(s.text for s in segments)}]
|
||||
for url in image_urls:
|
||||
content.append({"type": "image_url", "image_url": {"url": url, "detail": "auto"}})
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
|
||||
def size_estimator(messages):
|
||||
size = 0
|
||||
for part in messages[0]["content"]:
|
||||
if part["type"] == "text":
|
||||
size += len(part["text"])
|
||||
else:
|
||||
size += len(part["image_url"]["url"])
|
||||
return size
|
||||
|
||||
|
||||
class TestRequestChunker(unittest.TestCase):
|
||||
def test_chunk_segments_preserves_order_and_content(self):
|
||||
segments = [
|
||||
DummySeg(0, 1, "aaaa"),
|
||||
DummySeg(1, 2, "bbbb"),
|
||||
DummySeg(2, 3, "cccc"),
|
||||
]
|
||||
chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator)
|
||||
chunks = chunker.chunk(segments, [])
|
||||
texts = ["".join(seg.text for seg in c.segments) for c in chunks]
|
||||
self.assertEqual("".join(texts), "aaaabbbbcccc")
|
||||
self.assertTrue(all(texts))
|
||||
|
||||
def test_chunk_images_distributed_across_batches(self):
|
||||
segments = [DummySeg(0, 1, "aa")]
|
||||
images = ["i" * 6, "j" * 6, "k" * 6]
|
||||
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
|
||||
chunks = chunker.chunk(segments, images)
|
||||
all_images = [img for c in chunks for img in c.image_urls]
|
||||
self.assertEqual(all_images, images)
|
||||
|
||||
def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self):
|
||||
segments = [
|
||||
DummySeg(0, 1, "aaaaaa"),
|
||||
DummySeg(1, 2, "bbbbbb"),
|
||||
DummySeg(2, 3, "cccccc"),
|
||||
]
|
||||
images = ["11111", "22222", "33333"]
|
||||
chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator)
|
||||
chunks = chunker.chunk(segments, images)
|
||||
|
||||
self.assertGreaterEqual(len(chunks), 3)
|
||||
image_counts = [len(c.image_urls) for c in chunks]
|
||||
self.assertGreater(image_counts[1], 0)
|
||||
self.assertGreater(image_counts[2], 0)
|
||||
all_images = [img for c in chunks for img in c.image_urls]
|
||||
self.assertEqual(all_images, images)
|
||||
|
||||
def test_split_oversized_segment(self):
|
||||
segments = [DummySeg(0, 1, "x" * 25)]
|
||||
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
|
||||
chunks = chunker.chunk(segments, [])
|
||||
combined = "".join(seg.text for c in chunks for seg in c.segments)
|
||||
self.assertEqual(combined, "x" * 25)
|
||||
|
||||
def test_group_texts_by_budget(self):
|
||||
chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator)
|
||||
|
||||
def build_text_messages(texts, *_args, **_kwargs):
|
||||
content = [{"type": "text", "text": "".join(texts)}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages)
|
||||
self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user