class StreamOptimizer:
    """Stream output optimizer.

    Smooths streamed API responses by emitting text either character by
    character (short text) or in fixed-size chunks (long text), with an
    adaptive inter-emission delay computed from the total text length,
    so the output approximates a natural typing effect.
    """

    def __init__(self,
                 logger=None,
                 min_delay: float = 0.016,
                 max_delay: float = 0.024,
                 short_text_threshold: int = 10,
                 long_text_threshold: int = 100,
                 chunk_size: int = 10):
        """Initialize the stream output optimizer.

        Args:
            logger: optional logger used for diagnostics; pass None to
                disable logging entirely.
            min_delay: minimum delay between emissions (seconds).
            max_delay: maximum delay between emissions (seconds).
            short_text_threshold: texts up to this many characters use
                max_delay and are emitted character by character.
            long_text_threshold: texts at or above this many characters
                use min_delay and are emitted in chunks.
            chunk_size: chunk length (characters) for long-text output.
        """
        self.logger = logger
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.short_text_threshold = short_text_threshold
        self.long_text_threshold = long_text_threshold
        self.chunk_size = chunk_size

    def calculate_delay(self, text_length: int) -> float:
        """Compute the per-emission delay for a text of the given length.

        Short texts get max_delay, long texts get min_delay; lengths in
        between are interpolated logarithmically so the transition stays
        smooth.  (The original comment said "linear interpolation", but
        the implementation uses a log ratio — the comment was wrong.)

        Args:
            text_length: length of the text in characters.

        Returns:
            Delay in seconds, in [min_delay, max_delay].
        """
        if text_length <= self.short_text_threshold:
            # Short text: larger delay for a typing-like feel.
            return self.max_delay
        if text_length >= self.long_text_threshold:
            # Long text: smallest delay to keep throughput up.
            return self.min_delay
        # Medium text: logarithmic interpolation between the two bounds.
        ratio = (math.log(text_length / self.short_text_threshold)
                 / math.log(self.long_text_threshold / self.short_text_threshold))
        return self.max_delay - ratio * (self.max_delay - self.min_delay)

    def split_text_into_chunks(self, text: str) -> List[str]:
        """Split *text* into consecutive chunks of at most chunk_size chars.

        Args:
            text: the text to split.

        Returns:
            List of chunk strings; empty list for empty input.
        """
        step = self.chunk_size
        return [text[i:i + step] for i in range(0, len(text), step)]

    async def optimize_stream_output(self,
                                     text: str,
                                     create_response_chunk: Callable[[str], Any],
                                     format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
        """Yield formatted response chunks for *text* with adaptive pacing.

        Args:
            text: the text to emit; empty text yields nothing.
            create_response_chunk: builds a response object from a piece
                of text.
            format_chunk: serializes a response object to the wire format.

        Yields:
            Formatted chunk strings, one per character (short text) or
            per chunk_size-sized slice (long text), separated by the
            computed delay.
        """
        if not text:
            return

        # One delay for the whole text, derived from its total length.
        delay = self.calculate_delay(len(text))
        if self.logger:
            self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")

        if len(text) >= self.long_text_threshold:
            # Long text: emit in chunks to avoid thousands of tiny events.
            chunks = self.split_text_into_chunks(text)
            if self.logger:
                self.logger.info(f"Long text: splitting into {len(chunks)} chunks")
            for piece in chunks:
                yield format_chunk(create_response_chunk(piece))
                await asyncio.sleep(delay)
        else:
            # Short text: emit character by character for a typing effect.
            for char in text:
                yield format_chunk(create_response_chunk(char))
                await asyncio.sleep(delay)
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
    """Return the text of the first part of the first candidate.

    Returns "" when the response has no candidates, no parts, or the
    first part carries no "text" key (e.g. tool calls).
    """
    candidates = response.get("candidates")
    if not candidates:
        return ""
    parts = candidates[0].get("content", {}).get("parts", [])
    if parts and "text" in parts[0]:
        return parts[0].get("text", "")
    return ""

def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
    """Build a copy of *original_response* whose first part's text is *text*.

    The copy is made via a JSON round-trip (deep copy); the original is
    left untouched. If the expected candidates/content/parts structure is
    missing, the copy is returned unmodified.
    """
    response_copy = json.loads(json.dumps(original_response))  # deep copy
    candidates = response_copy.get("candidates")
    if candidates and candidates[0].get("content", {}).get("parts"):
        candidates[0]["content"]["parts"][0]["text"] = text
    return response_copy
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
    """Return the delta content of the first choice of an OpenAI chunk.

    Returns "" when there are no choices or the first choice's delta has
    no "content" key (e.g. tool calls or role-only deltas).
    """
    choices = chunk.get("choices")
    if not choices:
        return ""
    first = choices[0]
    if "delta" in first and "content" in first["delta"]:
        return first["delta"]["content"]
    return ""

def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
    """Build a copy of *original_chunk* whose delta content is *text*.

    The copy is made via a JSON round-trip (deep copy); the original is
    left untouched. If the chunk has no choices or no delta, the copy is
    returned unmodified.
    """
    chunk_copy = json.loads(json.dumps(original_chunk))  # deep copy
    choices = chunk_copy.get("choices")
    if choices and "delta" in choices[0]:
        choices[0]["delta"]["content"] = text
    return chunk_copy