mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-11 10:39:50 +08:00
feat: 优化 Gemini 模型思考过程的展示
This commit is contained in:
@@ -71,42 +71,69 @@ class ChatService:
|
||||
def __init__(self, base_url: str, key_manager=None):
|
||||
self.base_url = base_url
|
||||
self.key_manager = key_manager
|
||||
self.thinking_first = True
|
||||
|
||||
def convert_gemini_response_to_openai(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Convert Gemini response to OpenAI format"""
|
||||
if stream:
|
||||
try:
|
||||
text = ""
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["executableCode"])
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
if "thinking" in model:
|
||||
if len(parts) == 1:
|
||||
if self.thinking_first:
|
||||
self.thinking_first = False
|
||||
text = "\n🤔 **思考过程** 🤔\n---\n```\n" + parts[
|
||||
0
|
||||
].get("text")
|
||||
else:
|
||||
text = parts[0].get("text")
|
||||
elif len(parts) == 2:
|
||||
if self.thinking_first:
|
||||
self.thinking_first = False
|
||||
text = (
|
||||
"\n🤔 **思考过程** 🤔\n---\n```\n"
|
||||
+ parts[0].get("text")
|
||||
+ "\n```\n---\n"
|
||||
+ parts[1].get("text")
|
||||
)
|
||||
else:
|
||||
text = (
|
||||
parts[0].get("text")
|
||||
+ "\n```\n---\n"
|
||||
+ parts[1].get("text")
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
else:
|
||||
text = ""
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["executableCode"])
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = self.format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
text = self.add_search_link_text(model, candidate, text)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
@@ -136,7 +163,9 @@ class ChatService:
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0]["text"],
|
||||
"content": response["candidates"][0]["content"]["parts"][0][
|
||||
"text"
|
||||
],
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
@@ -149,8 +178,17 @@ class ChatService:
|
||||
}
|
||||
try:
|
||||
if response.get("candidates"):
|
||||
text = response["candidates"][0]["content"]["parts"][0]["text"]
|
||||
candidate = response["candidates"][0]
|
||||
if "thinking" in model:
|
||||
text = (
|
||||
"\n🤔 **思考过程** 🤔\n---\n```\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n```\n---\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
|
||||
text = self.add_search_link_text(model, candidate, text)
|
||||
res["choices"][0]["message"]["content"] = text
|
||||
return res
|
||||
@@ -160,30 +198,32 @@ class ChatService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting Gemini response: {str(e)}")
|
||||
logger.debug(f"Raw response: {response}")
|
||||
res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}"
|
||||
res["choices"][0]["message"][
|
||||
"content"
|
||||
] = f"Error converting Gemini response: {str(e)}"
|
||||
return res
|
||||
|
||||
def add_search_link_text(self, model, candidate, text):
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += f"**【引用来源】**\n\n"
|
||||
text += "**【引用来源】**\n\n"
|
||||
for _, grounding_chunk in enumerate(grounding_chunks, 1):
|
||||
if "web" in grounding_chunk:
|
||||
text += create_search_link(grounding_chunk["web"])
|
||||
return text
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Create chat completion using either Gemini or OpenAI API"""
|
||||
model = request.model
|
||||
@@ -191,7 +231,7 @@ class ChatService:
|
||||
if tools is None:
|
||||
tools = []
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
@@ -199,10 +239,10 @@ class ChatService:
|
||||
return await self._gemini_chat_completion(request, api_key, tools)
|
||||
|
||||
async def _gemini_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Handle Gemini API chat completion"""
|
||||
model = request.model
|
||||
@@ -252,6 +292,7 @@ class ChatService:
|
||||
}
|
||||
|
||||
if stream:
|
||||
|
||||
async def generate():
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
@@ -265,7 +306,7 @@ class ChatService:
|
||||
async with httpx.AsyncClient(timeout=timeout) as async_client:
|
||||
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
async with async_client.stream(
|
||||
"POST", stream_url, json=payload
|
||||
"POST", stream_url, json=payload
|
||||
) as async_response:
|
||||
if async_response.status_code != 200:
|
||||
error_content = await async_response.read()
|
||||
@@ -344,7 +385,9 @@ class ChatService:
|
||||
return generate()
|
||||
else:
|
||||
try:
|
||||
timeout = httpx.Timeout(300.0, read=300.0) # 连接超时300秒,读取超时300秒
|
||||
timeout = httpx.Timeout(
|
||||
300.0, read=300.0
|
||||
) # 连接超时300秒,读取超时300秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
@@ -370,7 +413,7 @@ class ChatService:
|
||||
return f"""\n【代码执行】\n```{language}\n{code}\n```\n"""
|
||||
|
||||
async def generate_content(
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> dict:
|
||||
"""调用Gemini API生成内容"""
|
||||
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
|
||||
@@ -393,7 +436,7 @@ class ChatService:
|
||||
raise
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
self, model_name: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator:
|
||||
"""调用Gemini API流式生成内容"""
|
||||
retries = 0
|
||||
@@ -407,7 +450,7 @@ class ChatService:
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST", url, json=request.model_dump()
|
||||
"POST", url, json=request.model_dump()
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = await response.text()
|
||||
|
||||
Reference in New Issue
Block a user