diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index 33832ac..b9945a7 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from copy import deepcopy +import json import asyncio from app.config.config import settings from app.log.logger import get_gemini_logger @@ -159,7 +160,6 @@ async def generate_content( ) return response - @router.post("/models/{model_name}:streamGenerateContent") @router_v1beta.post("/models/{model_name}:streamGenerateContent") @RetryHandler(key_arg="api_key") @@ -181,12 +181,42 @@ async def stream_generate_content( if not await model_service.check_model_support(model_name): raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") - response_stream = chat_service.stream_generate_content( + raw_stream = chat_service.stream_generate_content( model=model_name, request=request, api_key=api_key ) - return StreamingResponse(response_stream, media_type="text/event-stream") + try: + # 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON + first_chunk = await raw_stream.__anext__() + except StopAsyncIteration: + # 如果流直接结束,退回标准 SSE 输出 + return StreamingResponse(raw_stream, media_type="text/event-stream") + except Exception as e: + # 初始化流异常,直接返回 500 错误 + return JSONResponse( + content={"error": {"code": 500, "message": str(e)}}, + status_code=500 + ) + + # 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送 + if isinstance(first_chunk, str) and first_chunk.startswith("data:"): + async def combined(): + yield first_chunk + async for chunk in raw_stream: + yield chunk + + return StreamingResponse(combined(), media_type="text/event-stream") + + # 否则把首块当作错误 JSON 处理 + try: + err = json.loads(first_chunk) + code = err.get("error", {}).get("code", 500) + except json.JSONDecodeError: + err = {"error": {"code": 500, "message": first_chunk}} + code = 500 + + return JSONResponse(content=err, status_code=code) @router.post("/models/{model_name}:countTokens") diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py index 0c9f609..07551b7 100644 --- a/app/service/chat/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -470,6 +470,7 @@ class GeminiChatService: is_success = False status_code = None final_api_key = api_key + last_error_msg = None while retries < max_retries: request_datetime = datetime.datetime.now() @@ -509,6 +510,7 @@ class GeminiChatService: retries += 1 is_success = False error_log_msg = str(e) + last_error_msg = error_log_msg logger.warning( f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}" ) @@ -553,3 +555,26 @@ class GeminiChatService: latency_ms=latency_ms, request_time=request_datetime, ) + + # Emit final error SSE event if all retries failed + if not is_success: + # 从错误消息中提取嵌套JSON + parsed_error = None + if last_error_msg: + try: + # 查找JSON起始位置 + json_start = last_error_msg.find('{') + if json_start != -1: + json_str = last_error_msg[json_start:] + parsed_error = json.loads(json_str) + except json.JSONDecodeError: + pass + + error_data = { + "error": { + "code": parsed_error['error']['code'] if (parsed_error and 'error' in parsed_error and 'code' in parsed_error['error']) else (status_code or 500), + "message": parsed_error['error']['message'] if (parsed_error and 'error' in parsed_error and 'message' in parsed_error['error']) else (last_error_msg or "Streaming failed"), + "status": parsed_error['error']['status'] if (parsed_error and 'error' in parsed_error and 'status' in parsed_error['error']) else "INTERNAL" + } + } + yield json.dumps(error_data, ensure_ascii=False) \ No newline at end of file diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index 24cd12d..bb13dd9 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -742,4 +742,4 @@ class OpenAIChatService: status_code=status_code, latency_ms=latency_ms, request_time=request_datetime, - ) + ) \ No newline at end of file diff --git a/app/service/chat/vertex_express_chat_service.py b/app/service/chat/vertex_express_chat_service.py index a54ffc9..362b10e 100644 --- a/app/service/chat/vertex_express_chat_service.py +++ b/app/service/chat/vertex_express_chat_service.py @@ -400,4 +400,4 @@ class GeminiChatService: status_code=status_code, latency_ms=latency_ms, request_time=request_datetime, - ) + ) \ No newline at end of file