From 56f6f5e198d1a44c025ba811a7e34285e491608a Mon Sep 17 00:00:00 2001 From: snaily Date: Sat, 3 May 2025 20:37:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=9B=BE=E5=83=8F?= =?UTF-8?q?=E7=94=9F=E6=88=90=E6=B5=81=E5=BC=8F=E5=93=8D=E5=BA=94=E5=B9=B6?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 OpenAI 兼容路由的图像生成聊天添加流式支持。 - 重构 `gemini-2.0-flash-exp` 安全设置,使用常量统一管理。 - 更改图像生成默认响应格式为 `url`。 - 启用 `.env.example` 中的 `AUTH_TOKEN`。 - 清理部分代码注释。 --- .env.example | 2 +- app/domain/openai_models.py | 2 +- app/router/openai_routes.py | 8 +++++--- app/service/chat/gemini_chat_service.py | 23 ++++++---------------- app/service/chat/openai_chat_service.py | 3 ++- app/service/embedding/embedding_service.py | 2 +- 6 files changed, 16 insertions(+), 24 deletions(-) diff --git a/.env.example b/.env.example index bf1f74a..6cf18ec 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,7 @@ MYSQL_PASSWORD=change_me MYSQL_DATABASE=default_db API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"] ALLOWED_TOKENS=["sk-123456"] -# AUTH_TOKEN=sk-123456 +AUTH_TOKEN=sk-123456 TEST_MODEL=gemini-1.5-flash THINKING_MODELS=["gemini-2.5-flash-preview-04-17"] THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000} diff --git a/app/domain/openai_models.py b/app/domain/openai_models.py index 8e4f6ae..1f59702 100644 --- a/app/domain/openai_models.py +++ b/app/domain/openai_models.py @@ -32,4 +32,4 @@ class ImageGenerationRequest(BaseModel): size: Optional[str] = "1024x1024" quality: Optional[str] = None style: Optional[str] = None - response_format: Optional[str] = "b64_json" + response_format: Optional[str] = "url" diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py index 39c4dbc..b1b993e 100644 --- a/app/router/openai_routes.py +++ b/app/router/openai_routes.py @@ -90,7 +90,11 @@ async def chat_completion( if is_image_chat: # 图像生成聊天 response = await chat_service.create_image_chat_completion(request, current_api_key) - return response # 直接返回,不处理流式 + # 处理流式响应 + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + # 非流式直接返回结果 + return response else: # 普通聊天补全 response = await chat_service.create_chat_completion(request, current_api_key) @@ -111,8 +115,6 @@ async def generate_image( operation_name = "generate_image" async with handle_route_errors(logger, operation_name): logger.info(f"Handling image generation request for prompt: {request.prompt}") - # 注意:这里假设 image_create_service.generate_images 是同步函数 - # 如果它是异步的,需要 await response = image_create_service.generate_images(request) return response diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py index 0cebace..bf1a61c 100644 --- a/app/service/chat/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -2,17 +2,18 @@ import json import re -import datetime # Add datetime import -import time # Add time import +import datetime +import time from typing import Any, AsyncGenerator, Dict, List from app.config.config import settings +from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS from app.domain.gemini_models import GeminiRequest from app.handler.response_handler import GeminiResponseHandler from app.handler.stream_optimizer import gemini_optimizer from app.log.logger import get_gemini_logger from app.service.client.api_client import GeminiApiClient from app.service.key.key_manager import KeyManager -from app.database.services import add_error_log, add_request_log # Import add_request_log +from app.database.services import add_error_log, add_request_log logger = get_gemini_logger() @@ -73,20 +74,8 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: def _get_safety_settings(model: str) -> List[Dict[str, str]]: """获取安全设置""" if model == "gemini-2.0-flash-exp": - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, - ] - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, - ] + return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS + return settings.SAFETY_SETTINGS def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index 6fbf0e7..d551bdf 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import Any, AsyncGenerator, Dict, List, Optional, Union from app.config.config import settings +from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS from app.database.services import ( add_error_log, add_request_log, @@ -102,7 +103,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]: # and "gemini-2.0-pro-exp" not in model # ): if model == "gemini-2.0-flash-exp": - return settings.GEMINI_2_FLASH_EXP_SAFETY_SETTINGS + return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS return settings.SAFETY_SETTINGS diff --git a/app/service/embedding/embedding_service.py b/app/service/embedding/embedding_service.py index 2aac856..f78202a 100644 --- a/app/service/embedding/embedding_service.py +++ b/app/service/embedding/embedding_service.py @@ -39,7 +39,7 @@ class EmbeddingService: client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL) response = client.embeddings.create(input=input_text, model=model) is_success = True - status_code = 200 # Assume 200 OK on success + status_code = 200 return response except APIStatusError as e: is_success = False