feat: enhance streaming response handling with usage metadata support

This commit enhances streaming response handling. The main changes:

- **Parameter update**:
  - `_handle_openai_stream_response` now accepts a `usage_metadata` parameter so that usage information can be passed through to the response.

- **Data structure adjustment**:
  - When `usage_metadata` is provided, it is included in the returned JSON chunk, making the response information more complete.

- **Fake-stream logic update**:
  - Several methods in `OpenAIChatService` now pass usage metadata along when building and handling streaming responses.

These changes make streaming responses more flexible and informative, improving the user experience; a sketch of the updated handler follows.
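At its core, the change maps Gemini's camelCase `usageMetadata` token counts onto an OpenAI-style `usage` block on the stream chunk. A minimal sketch of the updated handler, simplified from the diff below (the real function also fills `delta` from the extracted text and tool calls; the stubbed `delta` here is for illustration only):

```python
import time
import uuid
from typing import Any, Dict, Optional


def _handle_openai_stream_response(
    response: Dict[str, Any],
    model: str,
    finish_reason: str,
    usage_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    # Simplified: the real handler extracts text/tool calls from `response`
    # via _extract_result; the delta is stubbed here for brevity.
    delta = {"content": "", "role": "assistant"}
    template_chunk = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    # Map Gemini's camelCase usageMetadata onto the OpenAI snake_case usage block.
    if usage_metadata:
        template_chunk["usage"] = {
            "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
            "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
            "total_tokens": usage_metadata.get("totalTokenCount", 0),
        }
    return template_chunk
```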
snaily
2025-05-09 18:57:10 +08:00
parent c85fe979e5
commit 11e45fca37
2 changed files with 36 additions and 105 deletions

View File

@@ -37,7 +37,7 @@ class GeminiResponseHandler(ResponseHandler):
 def _handle_openai_stream_response(
-    response: Dict[str, Any], model: str, finish_reason: str
+    response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
 ) -> Dict[str, Any]:
     text, tool_calls = _extract_result(
         response, model, stream=True, gemini_format=False
@@ -48,14 +48,16 @@ def _handle_openai_stream_response(
     delta = {"content": text, "role": "assistant"}
     if tool_calls:
         delta["tool_calls"] = tool_calls
-    return {
+    template_chunk = {
         "id": f"chatcmpl-{uuid.uuid4()}",
         "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
         "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
     }
+    if usage_metadata:
+        template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
+    return template_chunk

 def _handle_openai_normal_response(
@@ -101,7 +103,7 @@ class OpenAIResponseHandler(ResponseHandler):
         usage_metadata: Optional[Dict[str, Any]] = None,
     ) -> Optional[Dict[str, Any]]:
         if stream:
-            return _handle_openai_stream_response(response, model, finish_reason)
+            return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
         return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)

     def handle_image_chat_response(
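For reference, a stream chunk produced through the updated handler now carries a `usage` block whenever metadata is available. A sketch of the resulting shape (all values below are illustrative, including the id, timestamp, model name, and token counts):

```python
# Illustrative shape of a terminal chunk with the new usage block.
final_chunk = {
    "id": "chatcmpl-5f2c1e9a-...",
    "object": "chat.completion.chunk",
    "created": 1746787030,
    "model": "gemini-2.0-flash",
    "choices": [{"index": 0, "delta": {"content": "Hello!", "role": "assistant"}, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 84, "total_tokens": 96},
}
```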

View File

@@ -261,13 +261,7 @@ class OpenAIChatService:
             while keep_sending_empty_data:
                 await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
                 if keep_sending_empty_data:
-                    empty_chunk = {
-                        "id": f"chatcmpl-fake-heartbeat-{model}-{time.time()}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": model,
-                        "choices": [{"index": 0, "delta": {}, "finish_reason": None}],
-                    }
+                    empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
                     yield f"data: {json.dumps(empty_chunk)}\n\n"
                     logger.debug("Sent empty data chunk for fake stream heartbeat.")
@@ -284,63 +278,25 @@ class OpenAIChatService:
                     )
                     yield next_empty_chunk
                 except asyncio.TimeoutError:
-                    pass # Check api_response_task again
+                    pass
                 except (
                     StopAsyncIteration
-                ): # Should not happen if keep_sending_empty_data is managed
+                ):
                     break
-            response = await api_response_task # Get API response or exception
+            response = await api_response_task
         finally:
-            keep_sending_empty_data = False # Stop sending empty data
-
-        # Helper to create a base chunk for various scenarios
-        def create_base_chunk(role_content=""):
-            return {
-                "id": f"chatcmpl-fake-response-{model}-{time.time()}",
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": model,
-                "choices": [
-                    {
-                        "index": 0,
-                        "delta": {"role": "assistant", "content": role_content},
-                        "finish_reason": None,
-                    }
-                ],
-            }
+            keep_sending_empty_data = False

         if response and response.get("candidates"):
             candidate = response["candidates"][0]
-            if candidate.get("content") and candidate["content"].get("parts"):
-                full_text = "".join(
-                    part.get("text", "")
-                    for part in candidate["content"]["parts"]
-                    if part.get("text")
-                )
-                base_chunk_for_text = create_base_chunk()
-                final_chunk = self._create_char_openai_chunk(
-                    base_chunk_for_text, full_text
-                )
-                final_chunk["choices"][0]["finish_reason"] = "stop"
-                yield f"data: {json.dumps(final_chunk)}\n\n"
-                logger.info(f"Sent full response content for fake stream: {model}")
-            else:
-                logger.warning(
-                    f"Unexpected response structure (no parts/text) in fake stream for model {model}: {response}"
-                )
-                base_chunk_for_empty = create_base_chunk()
-                empty_final_chunk = self._create_char_openai_chunk(
-                    base_chunk_for_empty, ""
-                )
-                empty_final_chunk["choices"][0]["finish_reason"] = "stop"
-                yield f"data: {json.dumps(empty_final_chunk)}\n\n"
+            response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
+            yield f"data: {json.dumps(response)}\n\n"
+            logger.info(f"Sent full response content for fake stream: {model}")
         else:
             error_message = "Failed to get response from model"
             if (
                 response and isinstance(response, dict) and response.get("error")
-            ): # Check if response itself is an error structure
-                # Safely access nested 'message'
+            ):
                 error_details = response.get("error")
                 if isinstance(error_details, dict):
                     error_message = error_details.get("message", error_message)
@@ -348,11 +304,7 @@ class OpenAIChatService:
             logger.error(
                 f"No candidates or error in response for fake stream model {model}: {response}"
             )
-            base_chunk_for_error = create_base_chunk()
-            error_chunk = self._create_char_openai_chunk(
-                base_chunk_for_error, json.dumps({"error": error_message})
-            )
-            error_chunk["choices"][0]["finish_reason"] = "stop"
+            error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
             yield f"data: {json.dumps(error_chunk)}\n\n"

     async def _real_stream_logic_impl(
@@ -360,26 +312,27 @@ class OpenAIChatService:
     ) -> AsyncGenerator[str, None]:
         """处理真实流式 (real stream) 的核心逻辑"""
         tool_call_flag = False
+        usage_metadata = None
         async for line in self.api_client.stream_generate_content(
             payload, model, api_key
         ):
             if line.startswith("data:"):
                 chunk_str = line[6:]
-                if not chunk_str or chunk_str.isspace(): # handle empty data part
+                if not chunk_str or chunk_str.isspace():
                     logger.debug(
                         f"Received empty data line for model {model}, skipping."
                     )
                     continue
                 try:
                     chunk = json.loads(chunk_str)
+                    usage_metadata = chunk.get("usageMetadata", {})
                 except json.JSONDecodeError:
                     logger.error(
                         f"Failed to decode JSON from stream for model {model}: {chunk_str}"
                     )
-                    continue # Skip malformed chunk
+                    continue
                 openai_chunk = self.response_handler.handle_response(
-                    chunk, model, stream=True, finish_reason=None
+                    chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
                 )
                 if openai_chunk:
                     text = self._extract_text_from_openai_chunk(openai_chunk)
@@ -393,24 +346,15 @@ class OpenAIChatService:
                     ):
                         yield optimized_chunk_data
                 else:
-                    # Check for tool_calls more robustly
-                    if openai_chunk.get("choices") and openai_chunk["choices"][
-                        0
-                    ].get("delta", {}).get("tool_calls"):
-                        tool_call_flag = True
-                    elif openai_chunk.get("choices") and openai_chunk["choices"][
-                        0
-                    ].get("delta", {}).get(
-                        "function_call"
-                    ): # For older compatibility
+                    if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
+                        tool_call_flag = True
                     yield f"data: {json.dumps(openai_chunk)}\n\n"
         if tool_call_flag:
-            yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls'))}\n\n"
+            yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n"
         else:
-            yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
+            yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n"

     async def _handle_stream_completion(
         self, model: str, payload: Dict[str, Any], api_key: str
@@ -420,14 +364,12 @@ class OpenAIChatService:
         max_retries = settings.MAX_RETRIES
         is_success = False
         status_code = None
-        final_api_key = api_key # Initialize with the provided API key
+        final_api_key = api_key
         while retries < max_retries:
             start_time = time.perf_counter()
             request_datetime = datetime.datetime.now()
-            current_attempt_key = (
-                final_api_key # Use the potentially updated key for this attempt
-            )
+            current_attempt_key = final_api_key
             try:
                 stream_generator = None
@@ -449,19 +391,17 @@ class OpenAIChatService:
                 async for chunk_data in stream_generator:
                     yield chunk_data
-                # If the generator completes, it means all its data chunks (including stop/tool_calls) were yielded.
-                # Now, we send the [DONE] marker for the stream.
                 yield "data: [DONE]\n\n"
                 logger.info(
                     f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
                 )
                 is_success = True
                 status_code = 200
-                break # Successful attempt, exit retry loop
+                break
             except Exception as e:
                 retries += 1
-                is_success = False # Ensure is_success is false for this attempt
+                is_success = False
                 error_log_msg = str(e)
                 logger.warning(
                     f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
@@ -471,15 +411,10 @@ class OpenAIChatService:
                 if match:
                     status_code = int(match.group(1))
                 else:
-                    # Distinguish between client-side (e.g., asyncio.TimeoutError) and potential API errors
-                    if isinstance(
-                        e, asyncio.TimeoutError
-                    ): # Example, can add more specific client errors
-                        status_code = 408 # Request Timeout
+                    if isinstance(e, asyncio.TimeoutError):
+                        status_code = 408
                     else:
-                        status_code = (
-                            500 # Internal Server Error as default for other exceptions
-                        )
+                        status_code = 500

                 await add_error_log(
                     gemini_key=current_attempt_key,
@@ -495,7 +430,7 @@ class OpenAIChatService:
                         current_attempt_key, retries
                     )
                     if new_api_key and new_api_key != current_attempt_key:
-                        final_api_key = new_api_key # Update for the NEXT attempt
+                        final_api_key = new_api_key
                         logger.info(
                             f"Switched to new API key for next attempt: {final_api_key}"
                         )
@@ -503,36 +438,30 @@ class OpenAIChatService:
                         logger.error(
                             f"No valid API key available after {retries} retries, ceasing attempts for this request."
                         )
-                        break # No new key, stop retrying
-                    # If new_api_key is the same as current_attempt_key, continue retrying with it if retries < max_retries
+                        break
                 else:
                     logger.error(
                         "KeyManager not available, cannot switch API key. Ceasing attempts for this request."
                     )
-                    break # No KeyManager, stop retrying
+                    break

                 if retries >= max_retries:
                     logger.error(
                         f"Max retries ({max_retries}) reached for streaming model {model}."
                     )
-                    # The loop will terminate, and the final error handling outside the loop will take over.
             finally:
                 end_time = time.perf_counter()
                 latency_ms = int((end_time - start_time) * 1000)
-                # Log with the key used for THIS specific attempt
                 await add_request_log(
                     model_name=model,
                     api_key=current_attempt_key,
-                    is_success=is_success, # This reflects the success of the current attempt
+                    is_success=is_success,
                     status_code=status_code,
                     latency_ms=latency_ms,
                     request_time=request_datetime,
                 )

-        # After the loop, if not successful, yield a final error message and [DONE]
-        if (
-            not is_success
-        ): # This 'is_success' is the overall success status after all retries
+        if not is_success:
             logger.error(
                 f"Streaming failed permanently for model {model} after {retries} attempts."
             )
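Taken together, a client consuming this stream can now read token usage from the final chunk. A minimal consumer sketch (assumptions: the service exposes an OpenAI-compatible `/v1/chat/completions` SSE endpoint; the URL and model name are placeholders, not taken from this repository):

```python
import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "hi"}], "stream": True},
    stream=True,
)
usage = None
for raw in resp.iter_lines(decode_unicode=True):
    # SSE frames from this service look like: "data: {json}" or "data: [DONE]".
    if not raw or not raw.startswith("data: "):
        continue
    data = raw[len("data: "):]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    # With this commit, chunks may carry a "usage" block; remember the latest one.
    usage = chunk.get("usage", usage)
    for choice in chunk.get("choices", []):
        print(choice.get("delta", {}).get("content") or "", end="")
print("\nusage:", usage)
```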