Compare commits

...

10 Commits

Author SHA1 Message Date
snaily
fb523f4a2e feat: 将 StreamOptimizer 参数改为可配置
将 StreamOptimizer 中的硬编码参数改为通过配置文件可配置的参数,提高了系统的灵活性。具体修改包括:

在 .env.example 中添加 stream_optimizer 相关配置参数
在 app/core/config.py 中添加对应的配置项
修改 app/services/chat/stream_optimizer.py 从配置中读取参数
在 README.md 中添加流式输出优化配置的详细说明
2025-03-06 16:56:01 +08:00
snaily
40e5ffa5f4 chore: Adjust StreamOptimizer parameters for improved performance
- Reduced long_text_threshold from 100 to 50 characters
- Decreased chunk_size from 10 to 5

These changes aim to optimize the streaming output for better user experience
and responsiveness, particularly for medium-length texts.
2025-03-06 16:45:35 +08:00
snaily
0871548b07 feat: 添加流式输出优化器以改善聊天体验
新增StreamOptimizer类用于优化API响应的流式输出
实现智能延迟调整算法,根据文本长度动态计算延迟时间
添加长文本分块输出功能,提高大段文本的显示效果
将优化器集成到Gemini和OpenAI聊天服务中
优化后的输出更接近自然打字效果,提升用户体验
2025-03-06 15:53:58 +08:00
snaily
5a44a76c48 Merge remote-tracking branch 'BetterAndBetterII/main' 2025-03-03 18:45:56 +08:00
Toddy
7b5b6c7d4c if role is tool then set to user 2025-03-03 08:23:04 +00:00
Yuzhong Zhang
68ed4da789 Update Dockerfile 2025-03-03 14:09:45 +08:00
Yuzhong Zhang
cdbca7ec62 优化dockerfile,增加docker-compose,async openai 2025-03-03 13:55:09 +08:00
Yuzhong Zhang
48d58ef2e8 异步生成完成 2025-03-03 13:41:06 +08:00
snaily
88d483c1ef Merge pull request #4 from toddyoe/main
chore: add system instruction to enhance compliance with function call
2025-02-27 19:17:39 +08:00
Toddy
8d48db026c chore: add system instruction to enhance compliance with function call 2025-02-27 10:35:25 +00:00
11 changed files with 323 additions and 34 deletions

View File

@@ -13,3 +13,10 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
##########################################################################

View File

@@ -3,10 +3,10 @@ FROM python:3.10-slim
WORKDIR /app
# 复制所需文件到容器中
COPY ./app /app/app
COPY ./requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta

View File

@@ -76,6 +76,13 @@
# 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
```
### 配置说明
@@ -136,6 +143,24 @@
- 用途: 用于图片上传到 SM.MS 图床
- 获取方式: 需要在 SM.MS 官网注册并获取
#### 流式输出优化配置
- `STREAM_MIN_DELAY`: 最小延迟时间
- 默认值: `0.016`(秒)
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
- `STREAM_MAX_DELAY`: 最大延迟时间
- 默认值: `0.024`(秒)
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
- 默认值: `10`(字符)
- 说明: 长度小于或等于此值的文本被视为短文本,将使用最大延迟输出
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
- 默认值: `50`(字符)
- 说明: 长度大于或等于此值的文本被视为长文本,将使用最小延迟并分块输出
- `STREAM_CHUNK_SIZE`: 长文本分块大小
- 默认值: `5`(字符)
- 说明: 长文本分块输出时,每个块的大小
### ▶️ 运行
#### 使用 Docker (推荐)

View File

@@ -63,7 +63,7 @@ async def generate_content(
logger.info(f"Using API key: {api_key}")
try:
response = chat_service.generate_content(
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
@@ -122,7 +122,7 @@ async def verify_key(api_key: str):
)
]
)
response = chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
if response:
return JSONResponse({"status": "valid"})
return JSONResponse({"status": "invalid"})

View File

@@ -17,6 +17,13 @@ class Settings(BaseSettings):
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self):
super().__init__()

View File

@@ -24,13 +24,13 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url
self.timeout = timeout
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
with httpx.Client(timeout=timeout) as client:
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = client.post(url, json=payload)
response = await client.post(url, json=payload)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")

View File

@@ -1,14 +1,16 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import Any, Dict, List, Optional
SUPPORTED_ROLES = ["user", "model", "system"]
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
@@ -30,16 +32,26 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
system_instruction = None
if isinstance(msg["content"], str):
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空
if msg["content"]:
parts.append({"text": msg["content"]})
for idx, msg in enumerate(messages):
role = msg.get("role", "")
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user"
else:
# 如果是最后一条消息,则认为是用户消息
if idx == len(messages) - 1:
role = "user"
else:
role = "model"
parts = []
if isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
@@ -50,6 +62,10 @@ class OpenAIMessageConverter(MessageConverter):
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
converted_messages.append({"role": role, "parts": parts})
if parts:
if role == "system":
system_instruction = {"role": "system", "parts": parts}
else:
converted_messages.append({"role": role, "parts": parts})
return converted_messages
return converted_messages, system_instruction

View File

@@ -0,0 +1,132 @@
# app/services/chat/stream_optimizer.py
import asyncio
import math
from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
# One logger per chat service; each is passed to the default optimizer
# instance created for that service in this module.
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
class StreamOptimizer:
    """Re-emit a piece of streamed text in smaller pieces with a delay between them.

    Text shorter than ``long_text_threshold`` is emitted one character at a
    time; text at or above it is emitted in ``chunk_size``-character chunks.
    The pause between emissions shrinks as the text gets longer (see
    ``calculate_delay``).
    """

    def __init__(self,
                 logger=None,
                 min_delay: float = 0.016,
                 max_delay: float = 0.024,
                 short_text_threshold: int = 10,
                 long_text_threshold: int = 50,
                 chunk_size: int = 5):
        """Store the pacing parameters.

        Args:
            logger: Optional logger; when falsy, nothing is logged.
            min_delay: Delay in seconds used for long text.
            max_delay: Delay in seconds used for short text.
            short_text_threshold: Lengths up to and including this value use
                ``max_delay``.
            long_text_threshold: Lengths from this value upward use
                ``min_delay`` and are emitted in chunks.
            chunk_size: Number of characters per chunk for long text.
        """
        self.logger = logger
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.short_text_threshold = short_text_threshold
        self.long_text_threshold = long_text_threshold
        self.chunk_size = chunk_size

    def calculate_delay(self, text_length: int) -> float:
        """Return the pause, in seconds, to insert between emitted pieces.

        Args:
            text_length: Length of the text about to be emitted.

        Returns:
            ``max_delay`` for short text, ``min_delay`` for long text, and a
            value interpolated between the two for lengths in between.
        """
        # The thresholds come from configuration, so a zero or negative short
        # threshold is possible; the logarithm below needs a positive lower
        # bound, hence the floor of 1.
        lower = max(1, self.short_text_threshold)
        upper = self.long_text_threshold

        if text_length <= lower:
            return self.max_delay
        if text_length >= upper:
            return self.min_delay

        # Here lower < text_length < upper, so both logarithms are positive
        # and ratio lies strictly between 0 and 1. The interpolation is done
        # on a logarithmic scale of the length, not a linear one.
        ratio = math.log(text_length / lower) / math.log(upper / lower)
        return self.max_delay - ratio * (self.max_delay - self.min_delay)

    def split_text_into_chunks(self, text: str) -> List[str]:
        """Split ``text`` into consecutive chunks of ``chunk_size`` characters.

        Args:
            text: The text to split.

        Returns:
            The chunks in order; joining them reproduces ``text``. The last
            chunk may be shorter than ``chunk_size``.
        """
        # A step of 0 makes range() raise and a negative step yields no
        # chunks at all (the text would vanish), so fall back to 1.
        step = max(1, int(self.chunk_size))
        return [text[i:i + step] for i in range(0, len(text), step)]

    async def optimize_stream_output(self,
                                     text: str,
                                     create_response_chunk: Callable[[str], Any],
                                     format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
        """Yield ``text`` piece by piece, sleeping after each piece.

        Args:
            text: The text to emit. Empty text yields nothing.
            create_response_chunk: Builds a response object from a piece of
                text.
            format_chunk: Turns that response object into the string that is
                yielded.

        Yields:
            One formatted string per character (short text) or per chunk
            (text of at least ``long_text_threshold`` characters).
        """
        if not text:
            return

        delay = self.calculate_delay(len(text))
        if self.logger:
            self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")

        if len(text) >= self.long_text_threshold:
            pieces = self.split_text_into_chunks(text)
            if self.logger:
                self.logger.info(f"Long text: splitting into {len(pieces)} chunks")
        else:
            # Iterating a str yields it one character at a time.
            pieces = text

        for piece in pieces:
            yield format_chunk(create_response_chunk(piece))
            await asyncio.sleep(delay)
# Ready-made optimizer instances that other modules can import directly.
# Both read the same pacing values from the application settings and differ
# only in the logger they write to.
_stream_tuning = {
    "min_delay": settings.STREAM_MIN_DELAY,
    "max_delay": settings.STREAM_MAX_DELAY,
    "short_text_threshold": settings.STREAM_SHORT_TEXT_THRESHOLD,
    "long_text_threshold": settings.STREAM_LONG_TEXT_THRESHOLD,
    "chunk_size": settings.STREAM_CHUNK_SIZE,
}
openai_optimizer = StreamOptimizer(logger=logger_openai, **_stream_tuning)
gemini_optimizer = StreamOptimizer(logger=logger_gemini, **_stream_tuning)

View File

@@ -4,6 +4,7 @@ import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import gemini_optimizer
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
@@ -78,11 +79,31 @@ class GeminiChatService:
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
"""从响应中提取文本内容"""
if not response.get("candidates"):
return ""
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if parts and "text" in parts[0]:
return parts[0].get("text", "")
return ""
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
"""创建包含指定文本的响应"""
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
return response_copy
def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = self.api_client.generate_content(payload, model, api_key)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
@@ -96,8 +117,21 @@ class GeminiChatService:
# print(line)
if line.startswith("data:"):
line = line[6:]
line = json.dumps(self.response_handler.handle_response(json.loads(line), model, stream=True))
yield "data: " + line + "\n\n"
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
text = self._extract_text_from_response(response_data)
# 如果有文本内容,使用流式输出优化器处理
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_response(response_data, t),
lambda c: "data: " + json.dumps(c) + "\n\n"
):
yield optimized_chunk
else:
# 如果没有文本内容(如工具调用等),整块输出
yield "data: " + json.dumps(response_data) + "\n\n"
logger.info("Streaming completed successfully")
break
except Exception as e:

View File

@@ -2,11 +2,12 @@
from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import openai_optimizer
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
from app.core.config import settings
from app.services.image_create_service import ImageCreateService
@@ -57,7 +58,14 @@ def _build_tools(
function_declarations.append(function)
if function_declarations:
tools.append({"functionDeclarations": function_declarations})
# 按照 function 的 name 去重
names, functions = set(), []
for item in function_declarations:
if item.get("name") not in names:
names.add(item.get("name"))
functions.append(item)
tools.append({"functionDeclarations": functions})
return tools
@@ -87,10 +95,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]]
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""构建请求payload"""
return {
payload = {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
@@ -103,6 +111,16 @@ def _build_payload(
"safetySettings": _get_safety_settings(request.model),
}
if (
instruction
and isinstance(instruction, dict)
and instruction.get("role") == "system"
and instruction.get("parts")
):
payload["systemInstruction"] = instruction
return payload
class OpenAIChatService:
"""聊天服务"""
@@ -112,6 +130,23 @@ class OpenAIChatService:
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
"""从OpenAI响应块中提取文本内容"""
if not chunk.get("choices"):
return ""
choice = chunk["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
return choice["delta"]["content"]
return ""
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
"""创建包含指定文本的OpenAI响应块"""
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
chunk_copy["choices"][0]["delta"]["content"] = text
return chunk_copy
async def create_chat_completion(
self,
@@ -120,20 +155,20 @@ class OpenAIChatService:
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
messages = self.message_converter.convert(request.messages)
messages, instruction = self.message_converter.convert(request.messages)
# 构建请求payload
payload = _build_payload(request, messages)
payload = _build_payload(request, messages, instruction)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return self._handle_normal_completion(request.model, payload, api_key)
return await self._handle_normal_completion(request.model, payload, api_key)
def _handle_normal_completion(
async def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)
@@ -156,7 +191,19 @@ class OpenAIChatService:
chunk, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
# 提取文本内容
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
):
yield optimized_chunk
else:
# 如果没有文本内容(如工具调用等),整块输出
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
@@ -198,7 +245,19 @@ class OpenAIChatService:
image_data, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
# 提取文本内容
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
):
yield optimized_chunk
else:
# 如果没有文本内容(如图片URL等),整块输出
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
@@ -209,4 +268,4 @@ class OpenAIChatService:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)
)

9
docker-compose.yml Normal file
View File

@@ -0,0 +1,9 @@
version: '3'
services:
gemini-balance:
build: .
ports:
- "8000:8000"
env_file:
- .env