mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-01 05:39:47 +08:00
Merge pull request #347 from bbbugg:Add-final-SSE-error
Fix: Gemini streaming returns a structured error instead of empty responses
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import json
|
||||
import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
@@ -159,7 +160,6 @@ async def generate_content(
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
@@ -181,12 +181,42 @@ async def stream_generate_content(
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
raw_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": 500, "message": str(e)}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_stream:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
|
||||
# 否则把首块当作错误 JSON 处理
|
||||
try:
|
||||
err = json.loads(first_chunk)
|
||||
code = err.get("error", {}).get("code", 500)
|
||||
except json.JSONDecodeError:
|
||||
err = {"error": {"code": 500, "message": first_chunk}}
|
||||
code = 500
|
||||
|
||||
return JSONResponse(content=err, status_code=code)
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
|
||||
|
||||
@@ -470,6 +470,7 @@ class GeminiChatService:
|
||||
is_success = False
|
||||
status_code = None
|
||||
final_api_key = api_key
|
||||
last_error_msg = None
|
||||
|
||||
while retries < max_retries:
|
||||
request_datetime = datetime.datetime.now()
|
||||
@@ -509,6 +510,7 @@ class GeminiChatService:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
last_error_msg = error_log_msg
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
@@ -553,3 +555,26 @@ class GeminiChatService:
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
# Emit final error SSE event if all retries failed
|
||||
if not is_success:
|
||||
# 从错误消息中提取嵌套JSON
|
||||
parsed_error = None
|
||||
if last_error_msg:
|
||||
try:
|
||||
# 查找JSON起始位置
|
||||
json_start = last_error_msg.find('{')
|
||||
if json_start != -1:
|
||||
json_str = last_error_msg[json_start:]
|
||||
parsed_error = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
error_data = {
|
||||
"error": {
|
||||
"code": parsed_error['error']['code'] if (parsed_error and 'error' in parsed_error and 'code' in parsed_error['error']) else (status_code or 500),
|
||||
"message": parsed_error['error']['message'] if (parsed_error and 'error' in parsed_error and 'message' in parsed_error['error']) else (last_error_msg or "Streaming failed"),
|
||||
"status": parsed_error['error']['status'] if (parsed_error and 'error' in parsed_error and 'status' in parsed_error['error']) else "INTERNAL"
|
||||
}
|
||||
}
|
||||
yield json.dumps(error_data, ensure_ascii=False)
|
||||
@@ -742,4 +742,4 @@ class OpenAIChatService:
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
)
|
||||
@@ -400,4 +400,4 @@ class GeminiChatService:
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user