Files
gemini-balance/app/service/chat/gemini_chat_service.py
4Crusaders f51a4d20ad fix: 修复Gemini模型不支持同时使用tools和结构化输出的问题
- 添加_is_structured_output_request函数检测是否为结构化JSON输出请求
- 当检测到请求指定responseMimeType为application/json时,跳过gemini-balance主动添加的所有工具
- 仅在非结构化输出场景下,gemini-balance才会自动添加codeExecution、googleSearch、urlContext等工具
- 解决"Tool use with a response mime type: 'application/json' is unsupported"错误

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-06 14:07:33 +08:00

511 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# app/services/chat_service.py
import json
import re
import datetime
import time
from typing import Any, AsyncGenerator, Dict, List
from app.config.config import settings
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
from app.domain.gemini_models import GeminiRequest
from app.handler.response_handler import GeminiResponseHandler
from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
from app.database.services import add_error_log, add_request_log, get_file_api_key
from app.utils.helpers import redact_key_for_logging
logger = get_gemini_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
    """Collect "files/{id}" identifiers referenced via fileData parts.

    Only URIs of the form "{settings.BASE_URL}/files/{id}" are accepted;
    any other URI is logged as invalid and skipped.
    """
    # Hoist the pattern out of the loop; BASE_URL is constant per request.
    uri_pattern = re.compile(rf"{re.escape(settings.BASE_URL)}/(files/.*)")
    found: List[str] = []
    for message in contents:
        if "parts" not in message:
            continue
        for part in message["parts"]:
            if not isinstance(part, dict) or "fileData" not in part:
                continue
            file_data = part["fileData"]
            if "fileUri" not in file_data:
                continue
            file_uri = file_data["fileUri"]
            match = uri_pattern.match(file_uri)
            if match:
                file_id = match.group(1)
                found.append(file_id)
                logger.info(f"Found file reference: {file_id}")
            else:
                logger.warning(f"Invalid file URI: {file_uri}")
    return found
def _clean_json_schema_properties(obj: Any) -> Any:
"""清理JSON Schema中Gemini API不支持的字段"""
if not isinstance(obj, dict):
return obj
# Gemini API不支持的JSON Schema字段
unsupported_fields = {
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
"contentEncoding", "contentMediaType", "if", "then", "else",
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
"$id", "$ref", "$comment", "readOnly", "writeOnly"
}
cleaned = {}
for key, value in obj.items():
if key in unsupported_fields:
continue
if isinstance(value, dict):
cleaned[key] = _clean_json_schema_properties(value)
elif isinstance(value, list):
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
else:
cleaned[key] = value
return cleaned
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build the merged "tools" list for a Gemini request.

    Client-supplied tools are merged first; codeExecution / googleSearch /
    urlContext are then added based on settings and the model-name suffixes,
    and removed again for the combinations the Gemini API rejects
    (structured JSON output, function calling).

    NOTE: may normalize payload["tools"] from a dict to a one-element list
    in place.
    """
    def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
        """Return True if any content part contains a "functionCall" entry."""
        if not contents or not isinstance(contents, list):
            return False
        for content in contents:
            if not content or not isinstance(content, dict) or "parts" not in content:
                continue
            parts = content.get("parts", [])
            if not parts or not isinstance(parts, list):
                continue
            for part in parts:
                if isinstance(part, dict) and "functionCall" in part:
                    return True
        return False

    def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge a list of tool dicts into one; functionDeclarations are
        concatenated (after schema cleaning), other keys last-write-wins."""
        record = dict()
        for item in tools:
            if not item or not isinstance(item, dict):
                continue
            for k, v in item.items():
                if k == "functionDeclarations" and v and isinstance(v, list):
                    functions = record.get("functionDeclarations", [])
                    # Strip JSON Schema fields the Gemini API does not accept.
                    cleaned_functions = []
                    for func in v:
                        if isinstance(func, dict):
                            cleaned_func = _clean_json_schema_properties(func)
                            cleaned_functions.append(cleaned_func)
                        else:
                            cleaned_functions.append(func)
                    functions.extend(cleaned_functions)
                    record["functionDeclarations"] = functions
                else:
                    record[k] = v
        return record

    def _is_structured_output_request(payload: Dict[str, Any]) -> bool:
        """Return True when the request asks for structured JSON output."""
        try:
            generation_config = payload.get("generationConfig", {})
            return generation_config.get("responseMimeType") == "application/json"
        except (AttributeError, TypeError):
            return False

    tool = dict()
    if payload and isinstance(payload, dict) and "tools" in payload:
        # Normalize a single tool dict into a one-element list (in place).
        if payload.get("tools") and isinstance(payload.get("tools"), dict):
            payload["tools"] = [payload.get("tools")]
        items = payload.get("tools", [])
        if items and isinstance(items, list):
            tool.update(_merge_tools(items))
    # "Tool use with a response mime type: 'application/json' is unsupported"
    # The Gemini API rejects combining tools with structured output
    # (response_mime_type='application/json'); when the request asks for a
    # JSON response, skip adding every auto-injected tool to avoid the error.
    has_structured_output = _is_structured_output_request(payload)
    if not has_structured_output:
        if (
            settings.TOOLS_CODE_EXECUTION_ENABLED
            and not (model.endswith("-search") or "-thinking" in model)
            and not _has_image_parts(payload.get("contents", []))
        ):
            tool["codeExecution"] = {}
        if model.endswith("-search"):
            tool["googleSearch"] = {}
        real_model = _get_real_model(model)
        if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
            tool["urlContext"] = {}
    # Works around "Tool use with function calling is unsupported":
    # when function declarations/calls are present, drop the auto-added tools.
    if tool.get("functionDeclarations") or _has_function_call(payload.get("contents", [])):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)
        tool.pop("urlContext", None)
    return [tool] if tool else []
def _get_real_model(model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
if model.endswith("-non-thinking"):
model = model[:-13]
if "-search" in model and "-non-thinking" in model:
model = model[:-20]
return model
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings for *model*; gemini-2.0-flash-exp has its
    own dedicated settings, every other model uses the configured default."""
    is_flash_exp = model == "gemini-2.0-flash-exp"
    return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS if is_flash_exp else settings.SAFETY_SETTINGS
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filters out contents with empty or invalid parts."""
if not contents:
return []
filtered_contents = []
for content in contents:
if not content or "parts" not in content or not isinstance(content.get("parts"), list):
continue
valid_parts = [part for part in content["parts"] if isinstance(part, dict) and part]
if valid_parts:
new_content = content.copy()
new_content["parts"] = valid_parts
filtered_contents.append(new_content)
return filtered_contents
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Build the upstream Gemini API payload for *model* from *request*.

    - An unset maxOutputTokens is omitted entirely so the API applies its own
      default (avoids truncated responses).
    - TTS models get a reduced payload without tools / safetySettings.
    - "-image" models get responseModalities and no systemInstruction.
    - thinkingConfig: a client-supplied config always wins; otherwise defaults
      are derived from the model suffix and settings.THINKING_BUDGET_MAP.

    Args:
        model: Possibly-suffixed model name (e.g. "...-search", "...-image").
        request: Parsed Gemini request body.
    Returns:
        The payload dict ready for the upstream generateContent call.
    """
    request_dict = request.model_dump(exclude_none=False)
    if request.generationConfig:
        if request.generationConfig.maxOutputTokens is None:
            # If no max output length was specified, drop the field so the
            # API default applies instead of truncating the response.
            if "maxOutputTokens" in request_dict["generationConfig"]:
                request_dict["generationConfig"].pop("maxOutputTokens")

    # TTS models only accept a simplified payload without tools/safetySettings.
    is_tts_model = "tts" in model.lower()
    if is_tts_model:
        payload = {
            "contents": _filter_empty_parts(request_dict.get("contents", [])),
            "generationConfig": request_dict.get("generationConfig"),
        }
        # Only include systemInstruction when the client supplied one.
        if request_dict.get("systemInstruction"):
            payload["systemInstruction"] = request_dict.get("systemInstruction")
    else:
        # Non-TTS models get the full payload.
        payload = {
            "contents": _filter_empty_parts(request_dict.get("contents", [])),
            "tools": _build_tools(model, request_dict),
            "safetySettings": _get_safety_settings(model),
            "generationConfig": request_dict.get("generationConfig"),
            "systemInstruction": request_dict.get("systemInstruction"),
        }

    # Ensure generationConfig is a dict so it can be extended below.
    if payload["generationConfig"] is None:
        payload["generationConfig"] = {}

    if model.endswith("-image") or model.endswith("-image-generation"):
        payload.pop("systemInstruction")
        payload["generationConfig"]["responseModalities"] = ["Text", "Image"]

    # Thinking config: prefer the client's config, else fall back to defaults.
    client_thinking_config = None
    if request.generationConfig and request.generationConfig.thinkingConfig:
        client_thinking_config = request.generationConfig.thinkingConfig
    if client_thinking_config is not None:
        # The client provided a thinking config; use it verbatim.
        payload["generationConfig"]["thinkingConfig"] = client_thinking_config
    else:
        real_model = _get_real_model(model)
        if model.endswith("-non-thinking"):
            # gemini-2.5-pro cannot fully disable thinking; 128 is its minimum
            # budget, other models accept 0.
            budget = 128 if "gemini-2.5-pro" in model else 0
            payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": budget}
        elif real_model in settings.THINKING_BUDGET_MAP:
            # BUG FIX: look up the budget with the real (suffix-free) model
            # name. The old code checked membership with real_model but
            # fetched with the suffixed name, so suffixed models always got
            # the 1000 fallback instead of their configured budget.
            thinking_config = {
                "thinkingBudget": settings.THINKING_BUDGET_MAP.get(real_model, 1000)
            }
            if settings.SHOW_THINKING_PROCESS:
                thinking_config["includeThoughts"] = True
            payload["generationConfig"]["thinkingConfig"] = thinking_config
    return payload
class GeminiChatService:
    """Gemini chat service.

    Wraps GeminiApiClient with payload building, response handling,
    per-request success/latency logging, and (for streaming) retry with
    API-key rotation via the KeyManager.
    """
    def __init__(self, base_url: str, key_manager: KeyManager):
        """Create the service.

        Args:
            base_url: Upstream Gemini API base URL.
            key_manager: Key pool used to rotate keys on streaming failures.
        """
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.response_handler = GeminiResponseHandler()
    def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
        """Return the text of the first part of the first candidate, or ""."""
        if not response.get("candidates"):
            return ""
        candidate = response["candidates"][0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])
        if parts and "text" in parts[0]:
            return parts[0].get("text", "")
        return ""
    def _create_char_response(
        self, original_response: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of *original_response* (via a JSON round-trip)
        with its first text part replaced by *text*."""
        response_copy = json.loads(json.dumps(original_response))
        if response_copy.get("candidates") and response_copy["candidates"][0].get(
            "content", {}
        ).get("parts"):
            response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
        return response_copy
    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Generate content (non-streaming).

        Builds the payload, calls the upstream API once with *api_key*
        (or a file-specific key if the request references uploaded files),
        logs errors and request latency, and returns the handled response.

        Raises:
            Exception: re-raises whatever the API client raised, after logging.
        """
        # Check for file references; if found, switch to the file's own API
        # key (presumably files are bound to the key that uploaded them —
        # TODO confirm against the file-upload flow).
        file_names = _extract_file_references(request.model_dump().get("contents", []))
        if file_names:
            logger.info(f"Request contains file references: {file_names}")
            file_api_key = await get_file_api_key(file_names[0])
            if file_api_key:
                logger.info(f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}")
                api_key = file_api_key  # use the file's API key
            else:
                logger.warning(f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}")
        payload = _build_payload(model, request)
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(response, model, stream=False)
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Recover the upstream HTTP status from the error message if present.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="gemini-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload
            )
            raise e
        finally:
            # Always record the request outcome and latency.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )
    async def count_tokens(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Count tokens for the request via the countTokens API.

        Raises:
            Exception: re-raises whatever the API client raised, after logging.
        """
        # The countTokens API only needs the contents field.
        payload = {"contents": _filter_empty_parts(request.model_dump().get("contents", []))}
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.count_tokens(payload, model, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Count tokens API call failed with error: {error_log_msg}")
            # Recover the upstream HTTP status from the error message if present.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="gemini-count-tokens",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload
            )
            raise e
        finally:
            # Always record the request outcome and latency.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )
    async def stream_generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Generate content as a stream of SSE "data: ..." lines.

        Retries up to settings.MAX_RETRIES times, rotating to a new API key
        via the KeyManager after each failure; every attempt is logged with
        its outcome and latency.
        """
        # Check for file references; if found, switch to the file's own API
        # key (presumably files are bound to the key that uploaded them —
        # TODO confirm against the file-upload flow).
        file_names = _extract_file_references(request.model_dump().get("contents", []))
        if file_names:
            logger.info(f"Request contains file references: {file_names}")
            file_api_key = await get_file_api_key(file_names[0])
            if file_api_key:
                logger.info(f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}")
                api_key = file_api_key  # use the file's API key
            else:
                logger.warning(f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}")
        retries = 0
        max_retries = settings.MAX_RETRIES
        payload = _build_payload(model, request)
        is_success = False
        status_code = None
        final_api_key = api_key
        while retries < max_retries:
            request_datetime = datetime.datetime.now()
            start_time = time.perf_counter()
            current_attempt_key = api_key
            final_api_key = current_attempt_key
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, model, current_attempt_key
                ):
                    # print(line)
                    if line.startswith("data:"):
                        # NOTE(review): slices 6 chars, i.e. assumes the prefix
                        # is "data: " with a trailing space — confirm upstream
                        # SSE framing always includes it.
                        line = line[6:]
                        response_data = self.response_handler.handle_response(
                            json.loads(line), model, stream=True
                        )
                        text = self._extract_text_from_response(response_data)
                        # When there is text and the stream optimizer is
                        # enabled, let it re-chunk the output.
                        if text and settings.STREAM_OPTIMIZER_ENABLED:
                            # Emit optimizer-shaped chunks instead of the raw one.
                            async for (
                                optimized_chunk
                            ) in gemini_optimizer.optimize_stream_output(
                                text,
                                lambda t: self._create_char_response(response_data, t),
                                lambda c: "data: " + json.dumps(c) + "\n\n",
                            ):
                                yield optimized_chunk
                        else:
                            # No text content (e.g. tool calls): emit the whole chunk.
                            yield "data: " + json.dumps(response_data) + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                # Recover the upstream HTTP status from the error message if present.
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500
                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="gemini-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload
                )
                # Rotate to a fresh key for the next attempt, if any remains.
                api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
                if api_key:
                    logger.info(f"Switched to new API key: {redact_key_for_logging(api_key)}")
                else:
                    logger.error(f"No valid API key available after {retries} retries.")
                    break
                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming."
                    )
                    break
            finally:
                # Log the outcome/latency of every attempt, success or not.
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime
                )