feat: 优化 Gemini 模型思考过程的展示

This commit is contained in:
yinpeng
2024-12-24 23:27:20 +08:00
parent 98ba46f779
commit 5a1c3bdbe7

View File

@@ -71,42 +71,69 @@ class ChatService:
def __init__(self, base_url: str, key_manager=None):
    """Store the endpoint settings and reset the streaming state.

    Args:
        base_url: Base URL that the request URLs are built from.
        key_manager: Optional key-manager object; stored unchanged.
    """
    # Stays True until the first streamed chunk of a "thinking" model has
    # been converted; the stream converter flips it to False once it has
    # emitted the thinking-process header.
    self.thinking_first = True
    self.key_manager = key_manager
    self.base_url = base_url
def convert_gemini_response_to_openai(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None,
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None,
) -> Optional[Dict[str, Any]]:
"""Convert Gemini response to OpenAI format"""
if stream:
try:
text = ""
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = self.format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = self.format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = format_execution_result(
parts[0]["codeExecutionResult"]
)
if "thinking" in model:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
text = "\n🤔 **思考过程** 🤔\n---\n```\n" + parts[
0
].get("text")
else:
text = parts[0].get("text")
elif len(parts) == 2:
if self.thinking_first:
self.thinking_first = False
text = (
"\n🤔 **思考过程** 🤔\n---\n```\n"
+ parts[0].get("text")
+ "\n```\n---\n"
+ parts[1].get("text")
)
else:
text = (
parts[0].get("text")
+ "\n```\n---\n"
+ parts[1].get("text")
)
else:
text = ""
else:
text = ""
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = self.format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = self.format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = self.add_search_link_text(model, candidate, text)
else:
text = ""
return {
"id": f"chatcmpl-{uuid.uuid4()}",
@@ -136,7 +163,9 @@ class ChatService:
"index": 0,
"message": {
"role": "assistant",
"content": response["candidates"][0]["content"]["parts"][0]["text"],
"content": response["candidates"][0]["content"]["parts"][0][
"text"
],
},
"finish_reason": finish_reason,
}
@@ -149,8 +178,17 @@ class ChatService:
}
try:
if response.get("candidates"):
text = response["candidates"][0]["content"]["parts"][0]["text"]
candidate = response["candidates"][0]
if "thinking" in model:
text = (
"\n🤔 **思考过程** 🤔\n---\n```\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n```\n---\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
text = self.add_search_link_text(model, candidate, text)
res["choices"][0]["message"]["content"] = text
return res
@@ -160,30 +198,32 @@ class ChatService:
except Exception as e:
logger.error(f"Error converting Gemini response: {str(e)}")
logger.debug(f"Raw response: {response}")
res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}"
res["choices"][0]["message"][
"content"
] = f"Error converting Gemini response: {str(e)}"
return res
def add_search_link_text(self, model, candidate, text):
    """Append a cited-sources section to ``text`` for search-enabled models.

    The section is only added when ``settings.SHOW_SEARCH_LINK`` is truthy,
    the model name ends with ``-search``, and the candidate carries at least
    one grounding chunk with a ``web`` entry. In every other case ``text``
    is returned unchanged, so a heading with no links under it is never
    emitted.

    Args:
        model: Model name from the request.
        candidate: One candidate dict taken from the Gemini response.
        text: The response text converted so far.

    Returns:
        ``text`` itself, or ``text`` followed by a separator, the
        "引用来源" heading and one link per web grounding chunk.
    """
    if not (settings.SHOW_SEARCH_LINK and model.endswith("-search")):
        return text

    # Missing or null metadata is treated the same as "no sources".
    grounding_metadata = candidate.get("groundingMetadata") or {}
    grounding_chunks = grounding_metadata.get("groundingChunks") or []

    # Only chunks that carry a "web" entry can be rendered as links.
    links = [
        create_search_link(grounding_chunk["web"])
        for grounding_chunk in grounding_chunks
        if "web" in grounding_chunk
    ]
    if not links:
        return text

    return text + "\n\n---\n\n" + "**【引用来源】**\n\n" + "".join(links)
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Create chat completion using either Gemini or OpenAI API"""
model = request.model
@@ -191,7 +231,7 @@ class ChatService:
if tools is None:
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
model.endswith("-search") or "-thinking" in model
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
@@ -199,10 +239,10 @@ class ChatService:
return await self._gemini_chat_completion(request, api_key, tools)
async def _gemini_chat_completion(
self,
request: ChatRequest,
api_key: str,
tools: Optional[list] = None,
self,
request: ChatRequest,
api_key: str,
tools: Optional[list] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Handle Gemini API chat completion"""
model = request.model
@@ -252,6 +292,7 @@ class ChatService:
}
if stream:
async def generate():
retries = 0
max_retries = 3
@@ -265,7 +306,7 @@ class ChatService:
async with httpx.AsyncClient(timeout=timeout) as async_client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
async with async_client.stream(
"POST", stream_url, json=payload
"POST", stream_url, json=payload
) as async_response:
if async_response.status_code != 200:
error_content = await async_response.read()
@@ -344,7 +385,9 @@ class ChatService:
return generate()
else:
try:
timeout = httpx.Timeout(300.0, read=300.0) # 连接超时300秒读取超时300秒
timeout = httpx.Timeout(
300.0, read=300.0
) # 连接超时300秒读取超时300秒
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
@@ -370,7 +413,7 @@ class ChatService:
return f"""\n【代码执行】\n```{language}\n{code}\n```\n"""
async def generate_content(
self, model_name: str, request: GeminiRequest, api_key: str
self, model_name: str, request: GeminiRequest, api_key: str
) -> dict:
"""调用Gemini API生成内容"""
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
@@ -393,7 +436,7 @@ class ChatService:
raise
async def stream_generate_content(
self, model_name: str, request: GeminiRequest, api_key: str
self, model_name: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator:
"""调用Gemini API流式生成内容"""
retries = 0
@@ -407,7 +450,7 @@ class ChatService:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST", url, json=request.model_dump()
"POST", url, json=request.model_dump()
) as response:
if response.status_code != 200:
error_text = await response.text()