mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-05 14:51:31 +08:00
feat: 优化 Gemini 模型配置和请求参数
This commit is contained in:
@@ -53,13 +53,8 @@ async def chat_completion(
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
response = await chat_service.create_chat_completion(
|
||||
messages=request.messages,
|
||||
model=request.model,
|
||||
temperature=request.temperature,
|
||||
stream=request.stream,
|
||||
request=request,
|
||||
api_key=api_key,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
# 处理流式响应
|
||||
|
||||
@@ -8,7 +8,7 @@ class Settings(BaseSettings):
|
||||
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
||||
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
@@ -8,7 +8,10 @@ class ChatRequest(BaseModel):
|
||||
temperature: Optional[float] = 0.7
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[List[dict]] = []
|
||||
tool_choice: Optional[str] = "auto"
|
||||
max_tokens: Optional[int] = 8192
|
||||
stop: Optional[List[str]] = []
|
||||
top_p: Optional[float] = 0.9
|
||||
top_k: Optional[int] = 40
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, Any, Optional, AsyncGenerator, Union
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_chat_logger
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
from app.schemas.openai_models import ChatRequest
|
||||
|
||||
logger = get_chat_logger()
|
||||
|
||||
@@ -49,9 +50,8 @@ class ChatService:
|
||||
# 处理普通URL图片
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": image_url,
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -61,7 +61,11 @@ class ChatService:
|
||||
return converted_messages
|
||||
|
||||
def convert_gemini_response_to_openai(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, finish_reason: str = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Convert Gemini response to OpenAI format"""
|
||||
if stream:
|
||||
@@ -78,11 +82,28 @@ class ChatService:
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = self.format_execution_result(parts[0]["executableCodeResult"])
|
||||
text = self.format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = self.format_execution_result(parts[0]["codeExecutionResult"])
|
||||
text = self.format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
groundingChunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += f"**【引用来源】**\n\n"
|
||||
for _, groundingChunk in enumerate(groundingChunks, 1):
|
||||
if "web" in groundingChunk:
|
||||
text += self.create_search_link(groundingChunk["web"])
|
||||
else:
|
||||
text = ""
|
||||
|
||||
@@ -104,66 +125,87 @@ class ChatService:
|
||||
logger.debug(f"Raw response: {response}")
|
||||
return None
|
||||
else:
|
||||
return {
|
||||
res = {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0][
|
||||
"text"
|
||||
],
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0]["text"],
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
try:
|
||||
if response.get("candidates"):
|
||||
text = response["candidates"][0]["content"]["parts"][0]["text"]
|
||||
candidate = response["candidates"][0]
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
groundingChunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += f"**【引用来源】**\n\n"
|
||||
for _, groundingChunk in enumerate(groundingChunks, 1):
|
||||
if "web" in groundingChunk:
|
||||
text += self.create_search_link(groundingChunk["web"])
|
||||
res["choices"][0]["message"]["content"] = text
|
||||
return res
|
||||
else:
|
||||
res["choices"][0]["message"]["content"] = "暂无返回"
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting Gemini response: {str(e)}")
|
||||
logger.debug(f"Raw response: {response}")
|
||||
res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}"
|
||||
return res
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
stream: bool,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Create chat completion using either Gemini or OpenAI API"""
|
||||
|
||||
model = request.model
|
||||
tools = request.tools
|
||||
if tools is None:
|
||||
tools = []
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (model.endswith("-search") or "-thinking" in model):
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
return await self._gemini_chat_completion(
|
||||
messages, model, temperature, stream, api_key, tools
|
||||
)
|
||||
# else:
|
||||
# return await self._openai_chat_completion(
|
||||
# messages, model, temperature, stream, api_key, tools
|
||||
# )
|
||||
return await self._gemini_chat_completion(request, api_key, tools)
|
||||
|
||||
async def _gemini_chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
stream: bool,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Handle Gemini API chat completion"""
|
||||
model = request.model
|
||||
messages = request.messages
|
||||
temperature = request.temperature
|
||||
stream = request.stream
|
||||
max_tokens = request.max_tokens
|
||||
stop = request.stop
|
||||
top_p = request.top_p
|
||||
top_k = request.top_k
|
||||
if model.endswith("-search"):
|
||||
gemini_model = model[:-7] # Remove -search suffix
|
||||
else:
|
||||
@@ -176,14 +218,29 @@ class ChatService:
|
||||
tools.remove({"code_execution": {}})
|
||||
payload = {
|
||||
"contents": gemini_messages,
|
||||
"generationConfig": {"temperature": temperature},
|
||||
"generationConfig": {
|
||||
"temperature": temperature,
|
||||
"maxOutputTokens": max_tokens,
|
||||
"stopSequences": stop,
|
||||
"topP": top_p,
|
||||
"topK": top_k,
|
||||
},
|
||||
"tools": tools,
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
"threshold": "BLOCK_NONE",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -195,16 +252,26 @@ class ChatService:
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
timeout = httpx.Timeout(60.0, read=60.0) # 连接超时60秒,读取超时60秒
|
||||
timeout = httpx.Timeout(
|
||||
60.0, read=60.0
|
||||
) # 连接超时60秒,读取超时60秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
async with client.stream('POST', stream_url, json=payload) as response:
|
||||
async with client.stream(
|
||||
"POST", stream_url, json=payload
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.read()
|
||||
error_msg = error_content.decode('utf-8')
|
||||
logger.error(f"API error: {response.status_code}, {error_msg}")
|
||||
error_msg = error_content.decode("utf-8")
|
||||
logger.error(
|
||||
f"API error: {response.status_code}, {error_msg}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
current_api_key = (
|
||||
await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
else:
|
||||
@@ -218,8 +285,13 @@ class ChatService:
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
chunk = json.loads(line[6:])
|
||||
openai_chunk = self.convert_gemini_response_to_openai(
|
||||
chunk, model, stream=True, finish_reason=None
|
||||
openai_chunk = (
|
||||
self.convert_gemini_response_to_openai(
|
||||
chunk,
|
||||
model,
|
||||
stream=True,
|
||||
finish_reason=None,
|
||||
)
|
||||
)
|
||||
if openai_chunk:
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
@@ -230,20 +302,30 @@ class ChatService:
|
||||
return
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
|
||||
logger.warning(
|
||||
f"Read timeout occurred, attempting retry {retries + 1}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries reached. Final error: Read timeout")
|
||||
logger.error(
|
||||
f"Max retries reached. Final error: Read timeout"
|
||||
)
|
||||
yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n"
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Stream error: {str(e)}, attempting retry {retries + 1}")
|
||||
logger.exception(
|
||||
f"Stream error: {str(e)}, attempting retry {retries + 1}"
|
||||
)
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
@@ -262,9 +344,13 @@ class ChatService:
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
error_code = response.status_code
|
||||
raise Exception(f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}")
|
||||
raise Exception(
|
||||
f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}"
|
||||
)
|
||||
gemini_response = response.json()
|
||||
return self.convert_gemini_response_to_openai(gemini_response, model, finish_reason="stop")
|
||||
return self.convert_gemini_response_to_openai(
|
||||
gemini_response, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in non-stream completion")
|
||||
raise
|
||||
@@ -283,10 +369,7 @@ class ChatService:
|
||||
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> dict:
|
||||
"""调用Gemini API生成内容"""
|
||||
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||
@@ -301,16 +384,15 @@ class ChatService:
|
||||
error_text = response.text
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(error_text)
|
||||
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
|
||||
raise Exception(
|
||||
f"API request failed with status {response.status_code}: {error_text}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Request failed: {str(e)}")
|
||||
raise
|
||||
|
||||
async def stream_generate_content(
|
||||
self,
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator:
|
||||
"""调用Gemini API流式生成内容"""
|
||||
retries = 0
|
||||
@@ -321,19 +403,29 @@ class ChatService:
|
||||
try:
|
||||
url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
timeout = httpx.Timeout(60.0, read=60.0)
|
||||
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream('POST', url, json=request.model_dump()) as response:
|
||||
async with client.stream(
|
||||
"POST", url, json=request.model_dump()
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Error: {response.status_code}: {error_text}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
current_api_key = (
|
||||
await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Switched to new API key: {current_api_key}"
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
|
||||
|
||||
raise Exception(
|
||||
f"API request failed with status {response.status_code}: {error_text}"
|
||||
)
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
yield line + "\n\n"
|
||||
return
|
||||
@@ -341,7 +433,9 @@ class ChatService:
|
||||
except httpx.ReadTimeout:
|
||||
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
@@ -350,8 +444,13 @@ class ChatService:
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
current_api_key = await self.key_manager.handle_api_failure(
|
||||
current_api_key
|
||||
)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
raise
|
||||
|
||||
def create_search_link(self, web):
|
||||
return f'\n- [{web["title"]}]({web["uri"]})'
|
||||
|
||||
Reference in New Issue
Block a user