feat: 支持图像生成流式响应并优化配置

- 为 OpenAI 兼容路由的图像生成聊天添加流式支持。
- 重构 `gemini-2.0-flash-exp` 安全设置,使用常量统一管理。
- 更改图像生成默认响应格式为 `url`。
- 启用 `.env.example` 中的 `AUTH_TOKEN`。
- 清理部分代码注释。
This commit is contained in:
snaily
2025-05-03 20:37:09 +08:00
parent 929592bbc4
commit 56f6f5e198
6 changed files with 16 additions and 24 deletions

View File

@@ -6,7 +6,7 @@ MYSQL_PASSWORD=change_me
MYSQL_DATABASE=default_db
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
AUTH_TOKEN=sk-123456
TEST_MODEL=gemini-1.5-flash
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}

View File

@@ -32,4 +32,4 @@ class ImageGenerationRequest(BaseModel):
size: Optional[str] = "1024x1024"
quality: Optional[str] = None
style: Optional[str] = None
response_format: Optional[str] = "b64_json"
response_format: Optional[str] = "url"

View File

@@ -90,7 +90,11 @@ async def chat_completion(
if is_image_chat:
# 图像生成聊天
response = await chat_service.create_image_chat_completion(request, current_api_key)
return response # 直接返回,不处理流式
# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
# 非流式直接返回结果
return response
else:
# 普通聊天补全
response = await chat_service.create_chat_completion(request, current_api_key)
@@ -111,8 +115,6 @@ async def generate_image(
operation_name = "generate_image"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling image generation request for prompt: {request.prompt}")
# 注意:这里假设 image_create_service.generate_images 是同步函数
# 如果它是异步的,需要 await
response = image_create_service.generate_images(request)
return response

View File

@@ -2,17 +2,18 @@
import json
import re
import datetime # Add datetime import
import time # Add time import
import datetime
import time
from typing import Any, AsyncGenerator, Dict, List
from app.config.config import settings
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
from app.domain.gemini_models import GeminiRequest
from app.handler.response_handler import GeminiResponseHandler
from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
from app.database.services import add_error_log, add_request_log # Import add_request_log
from app.database.services import add_error_log, add_request_log
logger = get_gemini_logger()
@@ -73,20 +74,8 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
return settings.SAFETY_SETTINGS
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:

View File

@@ -8,6 +8,7 @@ from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
from app.database.services import (
add_error_log,
add_request_log,
@@ -102,7 +103,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return settings.GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
return settings.SAFETY_SETTINGS

View File

@@ -39,7 +39,7 @@ class EmbeddingService:
client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
response = client.embeddings.create(input=input_text, model=model)
is_success = True
status_code = 200 # Assume 200 OK on success
status_code = 200
return response
except APIStatusError as e:
is_success = False