mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-11 10:39:50 +08:00
- 为 Gemini 聊天(流式/非流式)、OpenAI 图像聊天(流式/非流式)和 embedding 服务的 API 调用实现全面的数据库日志记录。日志包括请求详情、成功/失败状态、状态码、延迟和错误消息。 - 重构 Gemini 流式聊天服务 (`stream_generate_content`) 以整合使用 `KeyManager` 的重试逻辑,与非流式实现保持一致,包括失败时的 API 密钥切换。 - 增强重试处理器 (`RetryHandler`) 的日志记录,以提高密钥切换和失败场景下的清晰度。 - 确保 `api_key` 正确传递给 OpenAI 图像聊天完成。 - 改进 embedding 服务中的错误处理,区分 `APIStatusError` 和通用异常,并将错误记录到数据库。 - 为 embedding 服务日志添加请求负载截断。 - 修复 Gemini `_build_payload` 中使用正确的 `model` 变量获取 `THINKING_BUDGET_MAP` 的错误。 - 移除 `ImageCreateService` 中未使用的 `paid_key` 类变量。
487 lines
21 KiB
Python
487 lines
21 KiB
Python
# app/services/chat_service.py
|
||
|
||
import json
|
||
import re
|
||
import datetime # Add datetime import
|
||
import time # Add time import
|
||
from copy import deepcopy
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||
|
||
from app.config.config import settings
|
||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||
from app.handler.message_converter import OpenAIMessageConverter
|
||
from app.handler.response_handler import OpenAIResponseHandler
|
||
from app.handler.stream_optimizer import openai_optimizer
|
||
from app.log.logger import get_openai_logger
|
||
from app.service.client.api_client import GeminiApiClient
|
||
from app.service.image.image_create_service import ImageCreateService
|
||
from app.service.key.key_manager import KeyManager
|
||
from app.database.services import add_error_log, add_request_log # Import add_request_log
|
||
|
||
logger = get_openai_logger()
|
||
|
||
|
||
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||
"""判断消息是否包含图片部分"""
|
||
for content in contents:
|
||
if "parts" in content:
|
||
for part in content["parts"]:
|
||
if "image_url" in part or "inline_data" in part:
|
||
return True
|
||
return False
|
||
|
||
|
||
def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build the ``tools`` list sent with a chat request.

    The result is either an empty list or a single-element list holding one
    tool dict that may contain:

    * ``codeExecution`` - when ``settings.TOOLS_CODE_EXECUTION_ENABLED`` is
      set, the model name is not a ``-search`` / ``-thinking`` / ``-image`` /
      ``-image-generation`` variant, and no message carries image parts;
    * ``googleSearch`` - when the model name ends with ``-search``;
    * ``functionDeclarations`` - the ``function`` entries of
      ``request.tools``, de-duplicated by ``name`` (first occurrence wins).

    When function declarations are present, ``googleSearch`` and
    ``codeExecution`` are removed again to avoid the
    "Tool use with function calling is unsupported" error.

    Args:
        request: Incoming chat request; ``model`` and ``tools`` are read.
        messages: Already converted messages, inspected for image parts.

    Returns:
        ``[tool]`` when at least one tool is enabled, otherwise ``[]``.
    """
    tool: Dict[str, Any] = {}
    model = request.model

    is_special_model = "-thinking" in model or model.endswith(
        ("-search", "-image", "-image-generation")
    )
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not is_special_model
        and not _has_image_parts(messages)
    ):
        tool["codeExecution"] = {}
    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Merge the caller-supplied function tools into the tool dict.
    if request.tools:
        seen_names = set()
        functions = []
        for item in request.tools:
            if not item or not isinstance(item, dict):
                continue
            if item.get("type", "") != "function":
                continue

            function_spec = item.get("function")
            # Skip malformed entries instead of crashing on ``.get`` below.
            if not function_spec or not isinstance(function_spec, dict):
                continue

            function = deepcopy(function_spec)
            # ``parameters`` may be missing or explicitly null; only a dict
            # can be inspected. An object schema without any properties is
            # dropped entirely.
            parameters = function.get("parameters")
            if (
                isinstance(parameters, dict)
                and parameters.get("type") == "object"
                and not parameters.get("properties", {})
            ):
                function.pop("parameters", None)

            # De-duplicate by function name, keeping the first declaration.
            name = function.get("name")
            if name in seen_names:
                continue
            seen_names.add(name)
            functions.append(function)

        if functions:
            tool["functionDeclarations"] = functions

    # Works around "Tool use with function calling is unsupported".
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
|
||
|
||
|
||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||
"""获取安全设置"""
|
||
# if (
|
||
# "2.0" in model
|
||
# and "gemini-2.0-flash-thinking-exp" not in model
|
||
# and "gemini-2.0-pro-exp" not in model
|
||
# ):
|
||
if model == "gemini-2.0-flash-exp":
|
||
return [
|
||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||
]
|
||
return [
|
||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||
]
|
||
|
||
|
||
def _build_payload(
    request: ChatRequest,
    messages: List[Dict[str, Any]],
    instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the request payload for a chat completion.

    Args:
        request: Incoming chat request supplying the sampling parameters,
            model name, token limit and tools.
        messages: Converted conversation, used as ``contents``.
        instruction: Optional system instruction. It is attached only when
            it is a dict with role ``"system"`` and non-empty ``parts``, and
            the model is not an ``-image`` / ``-image-generation`` variant.

    Returns:
        A dict with ``contents``, ``generationConfig``, ``tools``,
        ``safetySettings`` and, when applicable, ``systemInstruction``.
    """
    generation_config: Dict[str, Any] = {
        "temperature": request.temperature,
        "stopSequences": request.stop,
        "topP": request.top_p,
        "topK": request.top_k,
    }
    payload: Dict[str, Any] = {
        "contents": messages,
        "generationConfig": generation_config,
        "tools": _build_tools(request, messages),
        "safetySettings": _get_safety_settings(request.model),
    }

    model_name = request.model
    is_image_model = model_name.endswith(("-image", "-image-generation"))

    if request.max_tokens is not None:
        generation_config["maxOutputTokens"] = request.max_tokens
    if is_image_model:
        generation_config["responseModalities"] = ["Text", "Image"]
    if model_name.endswith("-non-thinking"):
        generation_config["thinkingConfig"] = {"thinkingBudget": 0}
    # A configured per-model budget is applied last, so it takes precedence
    # over the "-non-thinking" zero budget set just above.
    if model_name in settings.THINKING_BUDGET_MAP:
        budget = settings.THINKING_BUDGET_MAP.get(model_name, 1000)
        generation_config["thinkingConfig"] = {"thinkingBudget": budget}

    is_system_instruction = (
        bool(instruction)
        and isinstance(instruction, dict)
        and instruction.get("role") == "system"
        and bool(instruction.get("parts"))
    )
    if is_system_instruction and not is_image_model:
        payload["systemInstruction"] = instruction

    return payload
|
||
|
||
|
||
class OpenAIChatService:
    """Chat service exposing OpenAI-style completions on top of the Gemini client.

    Supports streaming and non-streaming chat completions as well as
    image-generation chat completions. Every upstream call is recorded via
    ``add_request_log``; failures are additionally recorded via
    ``add_error_log``.
    """

    def __init__(self, base_url: str, key_manager: Optional[KeyManager] = None):
        """Create the service.

        Args:
            base_url: Base URL handed to ``GeminiApiClient``.
            key_manager: Optional key manager used by the streaming path to
                obtain a replacement API key after a failed attempt.
        """
        self.message_converter = OpenAIMessageConverter()
        self.response_handler = OpenAIResponseHandler(config=None)
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.image_create_service = ImageCreateService()

    @staticmethod
    def _parse_status_code(error_message: str, default: int = 500) -> int:
        """Extract the HTTP status from an error text containing ``status code NNN``.

        Returns ``default`` when the text holds no such fragment.
        """
        match = re.search(r"status code (\d+)", error_message)
        return int(match.group(1)) if match else default

    @staticmethod
    def _to_sse(data: Any) -> str:
        """Serialize ``data`` as one server-sent-event ``data:`` frame."""
        return f"data: {json.dumps(data)}\n\n"

    @staticmethod
    def _chunk_has_tool_calls(chunk: Dict[str, Any]) -> bool:
        """Tell whether an OpenAI-style chunk carries tool calls.

        Looks at ``choices[*].delta.tool_calls`` and
        ``choices[*].message.tool_calls`` rather than searching the
        serialized chunk, so ordinary text that merely contains the word
        ``tool_calls`` is not mistaken for a tool call.
        """
        for choice in chunk.get("choices") or []:
            if not isinstance(choice, dict):
                continue
            for field in ("delta", "message"):
                body = choice.get(field)
                if isinstance(body, dict) and body.get("tool_calls"):
                    return True
        return False

    def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
        """Return the text of the first choice's ``delta.content``.

        Returns an empty string when there are no choices, no delta, no
        content key, or the content is ``None``.
        """
        if not chunk.get("choices"):
            return ""

        choice = chunk["choices"][0]
        if "delta" in choice and "content" in choice["delta"]:
            # Content may be null (e.g. tool-call chunks); keep the str contract.
            return choice["delta"]["content"] or ""
        return ""

    def _create_char_openai_chunk(
        self, original_chunk: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of ``original_chunk`` whose delta content is ``text``.

        The original chunk is left untouched. If the copy has no choices or
        its first choice has no ``delta``, the copy is returned unchanged.
        """
        # JSON round trip doubles as a deep copy of the JSON-shaped chunk.
        chunk_copy = json.loads(json.dumps(original_chunk))
        if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
            chunk_copy["choices"][0]["delta"]["content"] = text
        return chunk_copy

    async def _optimized_text_chunks(
        self, openai_chunk: Dict[str, Any], text: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE frames for ``text`` as produced by the stream optimizer.

        Each piece emitted by ``openai_optimizer`` is wrapped in a copy of
        ``openai_chunk`` and serialized as a ``data:`` frame.
        """
        async for optimized_chunk in openai_optimizer.optimize_stream_output(
            text,
            lambda t: self._create_char_openai_chunk(openai_chunk, t),
            lambda c: f"data: {json.dumps(c)}\n\n",
        ):
            yield optimized_chunk

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Args:
            request: The OpenAI-style chat request.
            api_key: API key used for the upstream call.

        Returns:
            An async generator of SSE frames when ``request.stream`` is
            truthy, otherwise the converted response dict.
        """
        # Convert the message format.
        messages, instruction = self.message_converter.convert(request.messages)

        # Build the request payload.
        payload = _build_payload(request, messages, instruction)

        if request.stream:
            return self._handle_stream_completion(request.model, payload, api_key)
        return await self._handle_normal_completion(request.model, payload, api_key)

    async def _handle_normal_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Run a non-streaming completion and log its outcome.

        Returns:
            The response converted by the response handler.

        Raises:
            Exception: Whatever the upstream call or the response conversion
                raised; it is logged via ``add_error_log`` and re-raised.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            is_success = True
            status_code = 200  # Assume 200 on success
            return self.response_handler.handle_response(
                response, model, stream=False, finish_reason="stop"
            )
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Recover the upstream status from the message; 500 if absent.
            status_code = self._parse_status_code(error_log_msg)

            await add_error_log(
                gemini_key=api_key,  # parameter is named gemini_key in add_error_log
                model_name=model,
                error_type="openai-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload,
            )
            raise
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Run a streaming completion with retry and API-key switching.

        Up to ``settings.MAX_RETRIES`` attempts are made. After a failed
        attempt the key manager, if any, is asked for a replacement key; the
        loop stops early when no key manager or no replacement key is
        available. Each attempt is logged separately via ``add_request_log``.

        Yields:
            SSE ``data:`` frames, always terminated by ``data: [DONE]``.
            When the stream ends unsuccessfully for any reason, an error
            frame is emitted before the terminator.
        """
        # NOTE(review): an attempt that fails after frames were already
        # yielded is retried from scratch, so the client may receive
        # duplicated content - confirm whether that is acceptable.
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None

        while retries < max_retries:
            # Timing and key are tracked per attempt.
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = api_key
            try:
                tool_call_flag = False
                async for line in self.api_client.stream_generate_content(
                    payload, model, current_attempt_key
                ):
                    if not line.startswith("data:"):
                        continue
                    # Slice right after the prefix; json.loads tolerates the
                    # optional space that follows "data:".
                    chunk = json.loads(line[len("data:"):])
                    openai_chunk = self.response_handler.handle_response(
                        chunk, model, stream=True, finish_reason=None
                    )
                    if not openai_chunk:
                        continue

                    text = self._extract_text_from_openai_chunk(openai_chunk)
                    if text and settings.STREAM_OPTIMIZER_ENABLED:
                        # Let the stream optimizer pace the text output.
                        async for optimized_chunk in self._optimized_text_chunks(
                            openai_chunk, text
                        ):
                            yield optimized_chunk
                    else:
                        # No text (e.g. tool calls) or optimizer disabled:
                        # forward the chunk as a whole.
                        if self._chunk_has_tool_calls(openai_chunk):
                            tool_call_flag = True
                        yield self._to_sse(openai_chunk)

                finish_reason = "tool_calls" if tool_call_flag else "stop"
                yield self._to_sse(
                    self.response_handler.handle_response(
                        {}, model, stream=True, finish_reason=finish_reason
                    )
                )
                yield "data: [DONE]\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200  # Assume 200 on success
                break  # Done, leave the retry loop.
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                # Recover the upstream status from the message; 500 if absent.
                status_code = self._parse_status_code(error_log_msg)

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                # Try to switch to another API key for the next attempt.
                if self.key_manager:
                    api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if api_key:
                        # NOTE(review): this writes the full key to the log;
                        # consider redacting it.
                        logger.info(f"Switched to new API key: {api_key}")
                    else:
                        logger.error(f"No valid API key available after {retries} retries.")
                        break  # No key left to retry with.
                else:
                    logger.error("KeyManager not available for retry logic.")
                    break  # Cannot switch keys without a key manager.

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming."
                    )
                    break
            finally:
                # Record the outcome of this attempt with the key it used.
                latency_ms = int((time.perf_counter() - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=current_attempt_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )

        # Always tell the client about a failed stream, including the early
        # exits above (no key manager / no replacement key), so the response
        # is never silently truncated without a terminator.
        if not is_success:
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
            yield "data: [DONE]\n\n"

    async def create_image_chat_completion(
        self,
        request: ChatRequest,
        api_key: str
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create an image-generation chat completion.

        The content of the last message is used as the image prompt.

        Args:
            request: The chat request; its last message supplies the prompt.
            api_key: API key recorded in the request/error logs.

        Returns:
            An async generator of SSE frames when ``request.stream`` is
            truthy, otherwise the converted response dict.
        """
        image_generate_request = ImageGenerationRequest()
        image_generate_request.prompt = request.messages[-1]["content"]
        # NOTE(review): called without await - presumably a synchronous
        # call; confirm it does not block the event loop for long.
        image_res = self.image_create_service.generate_images_chat(
            image_generate_request
        )

        if request.stream:
            return self._handle_stream_image_completion(request.model, image_res, api_key)
        return await self._handle_normal_image_completion(request.model, image_res, api_key)

    async def _handle_stream_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream an already generated image result as SSE frames.

        On failure the error is logged, sent to the client as an error
        frame and followed by ``data: [DONE]``; it is not re-raised.

        Yields:
            SSE ``data:`` frames terminated by ``data: [DONE]``.
        """
        logger.info(f"Starting stream image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            if image_data:
                openai_chunk = self.response_handler.handle_image_chat_response(
                    image_data, model, stream=True, finish_reason=None
                )
                if openai_chunk:
                    text = self._extract_text_from_openai_chunk(openai_chunk)
                    if text:
                        # Let the stream optimizer pace the text output.
                        async for optimized_chunk in self._optimized_text_chunks(
                            openai_chunk, text
                        ):
                            yield optimized_chunk
                    else:
                        # No text content (e.g. image URL): forward as a whole.
                        yield self._to_sse(openai_chunk)
            yield self._to_sse(
                self.response_handler.handle_response(
                    {}, model, stream=True, finish_reason="stop"
                )
            )
            logger.info(f"Stream image completion finished successfully for model: {model}")
            is_success = True
            status_code = 200
            yield "data: [DONE]\n\n"
        except Exception as e:
            is_success = False
            error_log_msg = f"Stream image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500  # Default error code
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be None/empty; never fail while logging.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            # Not re-raised: the stream has started, so report in-band.
            yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
            yield "data: [DONE]\n\n"
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            logger.info(f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}")
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_normal_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> Dict[str, Any]:
        """Convert an already generated image result into a chat response.

        Returns:
            The response produced by the response handler.

        Raises:
            Exception: Whatever the response conversion raised; it is logged
                via ``add_error_log`` and re-raised.
        """
        logger.info(f"Starting normal image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            result = self.response_handler.handle_image_chat_response(
                image_data, model, stream=False, finish_reason="stop"
            )
            logger.info(f"Normal image completion finished successfully for model: {model}")
            is_success = True
            status_code = 200
            return result
        except Exception as e:
            is_success = False
            error_log_msg = f"Normal image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500  # Default error code
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be None/empty; never fail while logging.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            # Let the caller see the failure.
            raise
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            logger.info(f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}")
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
|