diff --git a/app/domain/openai_models.py b/app/domain/openai_models.py index b3077ad..3916236 100644 --- a/app/domain/openai_models.py +++ b/app/domain/openai_models.py @@ -12,6 +12,7 @@ class ChatRequest(BaseModel): max_tokens: Optional[int] = None top_p: Optional[float] = DEFAULT_TOP_P top_k: Optional[int] = DEFAULT_TOP_K + n: Optional[int] = 1 stop: Optional[Union[List[str],str]] = None reasoning_effort: Optional[str] = None tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = [] diff --git a/app/handler/response_handler.py b/app/handler/response_handler.py index 7bb2d76..1300e71 100644 --- a/app/handler/response_handler.py +++ b/app/handler/response_handler.py @@ -42,21 +42,35 @@ class GeminiResponseHandler(ResponseHandler): def _handle_openai_stream_response( response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] ) -> Dict[str, Any]: - text, reasoning_content, tool_calls, _ = _extract_result( - response, model, stream=True, gemini_format=False - ) - if not text and not tool_calls and not reasoning_content: - delta = {} - else: - delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"} - if tool_calls: - delta["tool_calls"] = tool_calls + choices = [] + candidates = response.get("candidates", []) + + for candidate in candidates: + index = candidate.get("index", 0) + text, reasoning_content, tool_calls, _ = _extract_result( + {"candidates": [candidate]}, model, stream=True, gemini_format=False + ) + + if not text and not tool_calls and not reasoning_content: + delta = {} + else: + delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"} + if tool_calls: + delta["tool_calls"] = tool_calls + + choice = { + "index": index, + "delta": delta, + "finish_reason": finish_reason + } + choices.append(choice) + template_chunk = { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model, - "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], + "choices": choices, } if usage_metadata: template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)} @@ -66,26 +80,31 @@ def _handle_openai_stream_response( def _handle_openai_normal_response( response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] ) -> Dict[str, Any]: - text, reasoning_content, tool_calls, _ = _extract_result( - response, model, stream=False, gemini_format=False - ) + choices = [] + candidates = response.get("candidates", []) + + for i, candidate in enumerate(candidates): + text, reasoning_content, tool_calls, _ = _extract_result( + {"candidates": [candidate]}, model, stream=False, gemini_format=False + ) + choice = { + "index": i, + "message": { + "role": "assistant", + "content": text, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls, + }, + "finish_reason": finish_reason, + } + choices.append(choice) + return { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion", "created": int(time.time()), "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": text, - "reasoning_content": reasoning_content, - "tool_calls": tool_calls, - }, - "finish_reason": finish_reason, - } - ], + "choices": choices, "usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}, } diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index b627fc0..fb09af4 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -196,6 +196,10 @@ def _build_payload( # 处理 max_tokens 参数 _validate_and_set_max_tokens(payload, request.max_tokens, logger) + + # 处理 n 参数 + if request.n is not None and request.n > 0: + payload["generationConfig"]["candidateCount"] = request.n if request.model.endswith("-image") or request.model.endswith("-image-generation"): payload["generationConfig"]["responseModalities"] = ["Text", "Image"]