mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-02 22:30:21 +08:00
feat: 支持图像生成流式响应并优化配置
- 为 OpenAI 兼容路由的图像生成聊天添加流式支持。
- 重构 `gemini-2.0-flash-exp` 安全设置,使用常量统一管理。
- 更改图像生成默认响应格式为 `url`。
- 启用 `.env.example` 中的 `AUTH_TOKEN`。
- 清理部分代码注释。
This commit is contained in:
@@ -6,7 +6,7 @@ MYSQL_PASSWORD=change_me
|
||||
MYSQL_DATABASE=default_db
|
||||
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
||||
ALLOWED_TOKENS=["sk-123456"]
|
||||
# AUTH_TOKEN=sk-123456
|
||||
AUTH_TOKEN=sk-123456
|
||||
TEST_MODEL=gemini-1.5-flash
|
||||
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
|
||||
@@ -32,4 +32,4 @@ class ImageGenerationRequest(BaseModel):
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
response_format: Optional[str] = "b64_json"
|
||||
response_format: Optional[str] = "url"
|
||||
|
||||
@@ -90,7 +90,11 @@ async def chat_completion(
|
||||
if is_image_chat:
|
||||
# 图像生成聊天
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
return response # 直接返回,不处理流式
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
else:
|
||||
# 普通聊天补全
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
@@ -111,8 +115,6 @@ async def generate_image(
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
# 注意:这里假设 image_create_service.generate_images 是同步函数
|
||||
# 如果它是异步的,需要 await
|
||||
response = image_create_service.generate_images(request)
|
||||
return response
|
||||
|
||||
|
||||
@@ -2,17 +2,18 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
import datetime # Add datetime import
|
||||
import time # Add time import
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log # Import add_request_log
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -73,20 +74,8 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
|
||||
@@ -8,6 +8,7 @@ from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.database.services import (
|
||||
add_error_log,
|
||||
add_request_log,
|
||||
@@ -102,7 +103,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
# and "gemini-2.0-pro-exp" not in model
|
||||
# ):
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
return settings.GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class EmbeddingService:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
|
||||
response = client.embeddings.create(input=input_text, model=model)
|
||||
is_success = True
|
||||
status_code = 200 # Assume 200 OK on success
|
||||
status_code = 200
|
||||
return response
|
||||
except APIStatusError as e:
|
||||
is_success = False
|
||||
|
||||
Reference in New Issue
Block a user