mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-02 22:30:21 +08:00
refactor: 项目结构优化与FastAPI生命周期更新
重构项目目录结构,提高代码组织性和可维护性 将schemas目录重命名为domain,更好地表达领域模型概念 将services目录细分为service/chat、service/image等子目录 将api目录重命名为router,更符合FastAPI惯例 创建utils目录存放通用工具函数 更新FastAPI应用程序生命周期管理 替换已弃用的on_event方法为推荐的lifespan事件处理器 添加应用程序关闭时的日志记录 代码质量改进 抽取常量到constants.py,减少硬编码值 添加helpers.py提供通用工具函数 优化配置管理,使用环境变量和默认值 完善文档字符串,提高代码可读性
This commit is contained in:
306
app/handler/response_handler.py
Normal file
306
app/handler/response_handler.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# app/services/chat/response_handler.py
|
||||
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
import uuid
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
"""响应处理器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class GeminiResponseHandler(ResponseHandler):
|
||||
"""Gemini响应处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.thinking_first = True
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
return _handle_gemini_normal_response(response, model, stream)
|
||||
|
||||
|
||||
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
|
||||
if not text and not tool_calls:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
class OpenAIResponseHandler(ResponseHandler):
|
||||
"""OpenAI响应处理器"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.thinking_first = True
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
finish_reason: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason)
|
||||
return _handle_openai_normal_response(response, model, finish_reason)
|
||||
|
||||
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
|
||||
if stream:
|
||||
return _handle_openai_stream_image_response(image_str,model,finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str,model,finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": image_str
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
|
||||
text, tool_calls = "", []
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
return "", []
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = _format_code_block(parts[0]["executableCode"])
|
||||
elif "codeExecution" in parts[0]:
|
||||
text = _format_code_block(parts[0]["codeExecution"])
|
||||
elif "executableCodeResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["executableCodeResult"]
|
||||
)
|
||||
elif "codeExecutionResult" in parts[0]:
|
||||
text = _format_execution_result(
|
||||
parts[0]["codeExecutionResult"]
|
||||
)
|
||||
elif "inlineData" in parts[0]:
|
||||
text = _extract_image_data(parts[0])
|
||||
else:
|
||||
text = ""
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
if "thinking" in model:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = candidate["content"]["parts"][1]["text"]
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
text = ""
|
||||
if "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||
else:
|
||||
text = "暂无返回"
|
||||
return text, tool_calls
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
#将base64_data转成bytes数组
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data,filename)
|
||||
if upload_response.success:
|
||||
text = f""
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||
"""提取工具调用信息"""
|
||||
if not parts or not isinstance(parts, list):
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
continue
|
||||
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
id = f"call_{''.join(random.sample(letters, 32))}"
|
||||
name = item.get("name", "")
|
||||
arguments = json.dumps(item.get("args", None) or {})
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"index": i,
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": arguments},
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
def _format_code_block(code_data: dict) -> str:
|
||||
"""格式化代码块输出"""
|
||||
language = code_data.get("language", "").lower()
|
||||
code = code_data.get("code", "").strip()
|
||||
return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
|
||||
|
||||
|
||||
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
|
||||
if (
|
||||
settings.SHOW_SEARCH_LINK
|
||||
and model.endswith("-search")
|
||||
and "groundingMetadata" in candidate
|
||||
and "groundingChunks" in candidate["groundingMetadata"]
|
||||
):
|
||||
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
|
||||
text += "\n\n---\n\n"
|
||||
text += "**【引用来源】**\n\n"
|
||||
for _, grounding_chunk in enumerate(grounding_chunks, 1):
|
||||
if "web" in grounding_chunk:
|
||||
text += _create_search_link(grounding_chunk["web"])
|
||||
return text
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def _create_search_link(grounding_chunk: dict) -> str:
|
||||
return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})'
|
||||
|
||||
|
||||
def _format_execution_result(result_data: dict) -> str:
|
||||
"""格式化执行结果输出"""
|
||||
outcome = result_data.get("outcome", "")
|
||||
output = result_data.get("output", "").strip()
|
||||
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""
|
||||
Reference in New Issue
Block a user