refactor: 迁移媒体常量并重构相关处理逻辑

将音频/视频相关的配置(支持格式、大小限制、MIME类型)从 `config.py` 移动到 `core/constants.py`,以集中管理常量。

更新 `message_converter.py`:
- 从 `core.constants` 导入媒体常量。
- 添加并使用 `message_converter` 的专用日志记录器。
- 清理导入和代码格式。

更新 `openai_chat_service.py`:
- 调整 `_has_media_parts` 函数以正确检测 `inline_data`。
- 清理导入和代码格式。

在 `log/logger.py` 中添加 `get_message_converter_logger` 函数。

对 `config.py` 和 `response_handler.py` 进行了相关的移除和微小的代码清理。
This commit is contained in:
snaily
2025-04-29 17:54:48 +08:00
parent e822831178
commit e9d19de7c6
6 changed files with 389 additions and 298 deletions

View File

@@ -69,12 +69,6 @@ class Settings(BaseSettings):
# 日志配置
LOG_LEVEL: str = "INFO" # 默认日志级别
# Audio/Video Settings
SUPPORTED_AUDIO_FORMATS: List[str] = ["wav", "mp3", "flac", "ogg"] # Add formats Gemini supports
SUPPORTED_VIDEO_FORMATS: List[str] = ["mp4", "mov", "avi", "webm"] # Add formats Gemini supports
MAX_AUDIO_SIZE_BYTES: int = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
MAX_VIDEO_SIZE_BYTES: int = 200 * 1024 * 1024 # Example: 200MB limit
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 设置默认AUTH_TOKEN如果未提供
@@ -84,24 +78,6 @@ class Settings(BaseSettings):
# 创建全局配置实例
settings = Settings()
# Optional: Define MIME type mappings if needed, or handle directly in converter
AUDIO_FORMAT_TO_MIMETYPE = {
"wav": "audio/wav",
"mp3": "audio/mpeg",
"flac": "audio/flac",
"ogg": "audio/ogg",
# Add other mappings supported by Gemini
}
VIDEO_FORMAT_TO_MIMETYPE = {
"mp4": "video/mp4",
"mov": "video/quicktime",
"avi": "video/x-msvideo",
"webm": "video/webm",
# Add other mappings supported by Gemini
}
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
"""尝试将数据库字符串值解析为目标 Python 类型"""
from app.log.logger import get_config_logger # 函数内导入

View File

@@ -40,3 +40,24 @@ DEFAULT_STREAM_CHUNK_SIZE = 5
# 正则表达式模式
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
# Audio/Video Settings
# Media formats accepted for inline (Base64) audio/video payloads.
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
# Size limits are enforced against the *decoded* payload size
# (the converter decodes the Base64 data before comparing to these caps).
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024  # Example: 50MB limit for Base64 payload
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024  # Example: 200MB limit
# Optional: Define MIME type mappings if needed, or handle directly in converter
# Maps a file-extension-style format string to the MIME type sent upstream.
AUDIO_FORMAT_TO_MIMETYPE = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "ogg": "audio/ogg",
}
VIDEO_FORMAT_TO_MIMETYPE = {
    "mp4": "video/mp4",
    "mov": "video/quicktime",
    "avi": "video/x-msvideo",
    "webm": "video/webm",
}

View File

@@ -1,66 +1,70 @@
from abc import ABC, abstractmethod
import base64
import json
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import requests
import base64
import logging # Add logging
# Import settings and mappings
from app.config.config import settings, AUDIO_FORMAT_TO_MIMETYPE, VIDEO_FORMAT_TO_MIMETYPE
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
from app.core.constants import (
AUDIO_FORMAT_TO_MIMETYPE,
DATA_URL_PATTERN,
IMAGE_URL_PATTERN,
MAX_AUDIO_SIZE_BYTES,
MAX_VIDEO_SIZE_BYTES,
SUPPORTED_AUDIO_FORMATS,
SUPPORTED_ROLES,
SUPPORTED_VIDEO_FORMATS,
VIDEO_FORMAT_TO_MIMETYPE,
)
from app.log.logger import get_message_converter_logger
logger = logging.getLogger(__name__) # Add a logger
logger = get_message_converter_logger()
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
def convert(
self, messages: List[Dict[str, Any]]
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
def _get_mime_type_and_data(base64_string):
"""
从 base64 字符串中提取 MIME 类型和数据。
参数:
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
返回:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
if base64_string.startswith("data:"):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
mime_type = (
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return {
"inline_data": {
"mime_type": mime_type,
"data": encoded_data
}
}
return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}
else:
encoded_data = _convert_image_to_base64(image_url)
return {
"inline_data": {
"mime_type": "image/png",
"data": encoded_data
}
}
return {"inline_data": {"mime_type": "image/png", "data": encoded_data}}
def _convert_image_to_base64(url: str) -> str:
@@ -74,7 +78,7 @@ def _convert_image_to_base64(url: str) -> str:
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
img_data = base64.b64encode(response.content).decode("utf-8")
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
@@ -98,12 +102,9 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
# 将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append({
"inlineData": {
"mimeType": "image/png",
"data": base64_data
}
})
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
@@ -116,20 +117,30 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def _validate_media_data(self, format: str, data: str, supported_formats: List[str], max_size: int) -> tuple[Optional[str], Optional[str]]:
def _validate_media_data(
self, format: str, data: str, supported_formats: List[str], max_size: int
) -> tuple[Optional[str], Optional[str]]:
"""Validates format and size of Base64 media data."""
if format.lower() not in supported_formats:
logger.error(f"Unsupported media format: {format}. Supported: {supported_formats}")
logger.error(
f"Unsupported media format: {format}. Supported: {supported_formats}"
)
raise ValueError(f"Unsupported media format: {format}")
try:
# Decode Base64 to check size
# Be careful with memory usage for very large files
# Consider streaming decoding or checking length heuristic first if memory is a concern
decoded_data = base64.b64decode(data, validate=True) # Use validate=True for stricter check
decoded_data = base64.b64decode(
data, validate=True
) # Use validate=True for stricter check
if len(decoded_data) > max_size:
logger.error(f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes).")
raise ValueError(f"Media data size exceeds limit of {max_size // 1024 // 1024}MB")
logger.error(
f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
)
raise ValueError(
f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
)
# No need to return decoded_data, just the original base64 if valid
return data
except base64.binascii.Error as e:
@@ -137,9 +148,11 @@ class OpenAIMessageConverter(MessageConverter):
raise ValueError("Invalid Base64 data")
except Exception as e:
logger.error(f"Error validating media data: {e}")
raise # Re-raise other potential errors
raise
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
def convert(
self, messages: List[Dict[str, Any]]
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction_parts = []
@@ -147,33 +160,48 @@ class OpenAIMessageConverter(MessageConverter):
role = msg.get("role", "")
parts = []
# --- Start Modification ---
if "content" in msg and isinstance(msg["content"], list):
for content_item in msg["content"]:
if not isinstance(content_item, dict):
# Skip non-dict items if any unexpected format appears
logger.warning(f"Skipping unexpected content item format: {type(content_item)}")
logger.warning(
f"Skipping unexpected content item format: {type(content_item)}"
)
continue
content_type = content_item.get("type")
if content_type == "text" and content_item.get("text"):
parts.append({"text": content_item["text"]})
elif content_type == "image_url" and content_item.get("image_url", {}).get("url"):
elif content_type == "image_url" and content_item.get(
"image_url", {}
).get("url"):
try:
parts.append(_convert_image(content_item["image_url"]["url"]))
parts.append(
_convert_image(content_item["image_url"]["url"])
)
except Exception as e:
logger.error(f"Failed to convert image URL {content_item['image_url']['url']}: {e}")
# Decide how to handle: skip part, add error text, etc.
parts.append({"text": f"[Error processing image: {content_item['image_url']['url']}]"})
logger.error(
f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
)
# Decide how to handle: skip part, add error text, etc.
parts.append(
{
"text": f"[Error processing image: {content_item['image_url']['url']}]"
}
)
# --- Add handling for input_audio ---
elif content_type == "input_audio" and content_item.get("input_audio"):
elif content_type == "input_audio" and content_item.get(
"input_audio"
):
audio_info = content_item["input_audio"]
audio_data = audio_info.get("data")
audio_format = audio_info.get("format", "").lower()
if not audio_data or not audio_format:
logger.warning("Skipping audio part due to missing data or format.")
logger.warning(
"Skipping audio part due to missing data or format."
)
continue
try:
@@ -181,148 +209,151 @@ class OpenAIMessageConverter(MessageConverter):
validated_data = self._validate_media_data(
audio_format,
audio_data,
settings.SUPPORTED_AUDIO_FORMATS,
settings.MAX_AUDIO_SIZE_BYTES
SUPPORTED_AUDIO_FORMATS,
MAX_AUDIO_SIZE_BYTES,
)
# Get MIME type
mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format)
if not mime_type:
# Should not happen if format validation passed, but double-check
logger.error(f"Could not find MIME type for supported format: {audio_format}")
raise ValueError(f"Internal error: MIME type mapping missing for {audio_format}")
logger.error(
f"Could not find MIME type for supported format: {audio_format}"
)
raise ValueError(
f"Internal error: MIME type mapping missing for {audio_format}"
)
parts.append({
"inlineData": {
"mimeType": mime_type,
"data": validated_data # Use the validated Base64 data
parts.append(
{
"inline_data": {
"mimeType": mime_type,
"data": validated_data, # Use the validated Base64 data
}
}
})
logger.debug(f"Successfully added audio part (format: {audio_format})")
)
logger.debug(
f"Successfully added audio part (format: {audio_format})"
)
except ValueError as e:
logger.error(f"Skipping audio part due to validation error: {e}")
# Add placeholder text indicating the error
logger.error(
f"Skipping audio part due to validation error: {e}"
)
parts.append({"text": f"[Error processing audio: {e}]"})
except Exception as e:
logger.exception(f"Unexpected error processing audio part.")
parts.append({"text": "[Unexpected error processing audio]"})
except Exception:
logger.exception("Unexpected error processing audio part.")
parts.append(
{"text": "[Unexpected error processing audio]"}
)
# --- Add handling for input_video (similar pattern) ---
elif content_type == "input_video" and content_item.get("input_video"):
elif content_type == "input_video" and content_item.get(
"input_video"
):
video_info = content_item["input_video"]
video_data = video_info.get("data")
video_format = video_info.get("format", "").lower()
if not video_data or not video_format:
logger.warning("Skipping video part due to missing data or format.")
logger.warning(
"Skipping video part due to missing data or format."
)
continue
try:
validated_data = self._validate_media_data(
video_format,
video_data,
settings.SUPPORTED_VIDEO_FORMATS,
settings.MAX_VIDEO_SIZE_BYTES
SUPPORTED_VIDEO_FORMATS,
MAX_VIDEO_SIZE_BYTES,
)
mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format)
if not mime_type:
raise ValueError(f"Internal error: MIME type mapping missing for {video_format}")
raise ValueError(
f"Internal error: MIME type mapping missing for {video_format}"
)
parts.append({
"inlineData": {
"mimeType": mime_type,
"data": validated_data
parts.append(
{
"inline_data": {
"mimeType": mime_type,
"data": validated_data,
}
}
})
logger.debug(f"Successfully added video part (format: {video_format})")
)
logger.debug(
f"Successfully added video part (format: {video_format})"
)
except ValueError as e:
logger.error(f"Skipping video part due to validation error: {e}")
parts.append({"text": f"[Error processing video: {e}]"})
except Exception as e:
logger.exception(f"Unexpected error processing video part.")
parts.append({"text": "[Unexpected error processing video]"})
# --- End new media handling ---
logger.error(
f"Skipping video part due to validation error: {e}"
)
parts.append({"text": f"[Error processing video: {e}]"})
except Exception:
logger.exception("Unexpected error processing video part.")
parts.append(
{"text": "[Unexpected error processing video]"}
)
else:
# Log unrecognized but present types
if content_type:
logger.warning(f"Unsupported content type or missing data in structured content: {content_type}")
# Silently ignore items without a 'type' or if structure is unexpected
logger.warning(
f"Unsupported content type or missing data in structured content: {content_type}"
)
# --- End Modification for list content ---
# Keep processing for simple string content (might contain image markdown)
elif "content" in msg and isinstance(msg["content"], str) and msg["content"]:
# This path handles simple text or markdown images.
# If you expect audio/video ONLY via the structured list format,
# this part remains as is. If you might have URLs in plain text,
# you'd need more complex regex parsing here.
elif (
"content" in msg and isinstance(msg["content"], str) and msg["content"]
):
parts.extend(_process_text_with_image(msg["content"]))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
# Keep existing tool call processing
for tool_call in msg["tool_calls"]:
function_call = tool_call.get("function",{})
# Sanitize arguments loading
arguments_str = function_call.get("arguments","{}")
try:
function_call["args"] = json.loads(arguments_str)
except json.JSONDecodeError:
logger.warning(f"Failed to decode tool call arguments: {arguments_str}")
function_call["args"] = {} # Assign empty dict on error
if "arguments" in function_call: # Check before deleting
# Ensure 'arguments' key exists before attempting deletion
# In some OpenAI versions, it might already be absent
pass # No explicit delete needed if structure is {'function': {'name': '...', 'args': ...}}
else:
# If 'arguments' was the source key, delete it after parsing
if 'arguments' in function_call: # Check again just in case
# Keep existing tool call processing
for tool_call in msg["tool_calls"]:
function_call = tool_call.get("function", {})
# Sanitize arguments loading
arguments_str = function_call.get("arguments", "{}")
try:
function_call["args"] = json.loads(arguments_str)
except json.JSONDecodeError:
logger.warning(
f"Failed to decode tool call arguments: {arguments_str}"
)
function_call["args"] = {}
if "arguments" in function_call:
if "arguments" in function_call:
del function_call["arguments"]
parts.append({"functionCall": function_call})
parts.append({"functionCall": function_call})
# Role assignment and message appending logic (keep as is)
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user" # Gemini uses 'user' role for function/tool responses
# ... (rest of role handling logic) ...
role = "user"
else:
# Fallback role logic
if idx == len(messages) - 1:
role = "user"
else:
# Previous logic assigned 'model'. Check if this is always correct.
# Tool/Function responses are usually 'model' in Gemini after the 'user' (tool result) turn.
role = "model" # Stick to 'model' as the default fallback for non-user/system/tool
# 如果是最后一条消息,则认为是用户消息
if idx == len(messages) - 1:
role = "user"
else:
role = "model"
if parts:
if role == "system":
# Check if system instructions can contain media - unlikely based on Gemini docs
# Filter out non-text parts for safety?
text_only_parts = [p for p in parts if "text" in p]
if len(text_only_parts) != len(parts):
logger.warning("Non-text parts found in system message; discarding them.")
if text_only_parts:
if role == "system":
text_only_parts = [p for p in parts if "text" in p]
if len(text_only_parts) != len(parts):
logger.warning(
"Non-text parts found in system message; discarding them."
)
if text_only_parts:
system_instruction_parts.extend(text_only_parts)
else:
# Ensure role is mapped correctly ('model' for assistant turns, 'user' for tool result turns)
gemini_role = "model" if role == "assistant" else role # 'tool' role already mapped to 'user'
converted_messages.append({"role": gemini_role, "parts": parts})
else:
converted_messages.append({"role": role, "parts": parts})
system_instruction = (
None
if not system_instruction_parts
else {
"role": "system", # Gemini supports a dedicated system instruction
"role": "system",
"parts": system_instruction_parts,
}
)
# Gemini expects 'model' for assistant turns, and 'user' for function/tool responses.
# The role mapping logic above should handle this correctly now.
# Debug: Log the final converted structure before returning
# logger.debug(f"Converted messages for Gemini: {json.dumps(converted_messages, indent=2)}")
# if system_instruction:
# logger.debug(f"System instruction for Gemini: {json.dumps(system_instruction, indent=2)}")
return converted_messages, system_instruction

View File

@@ -1,12 +1,12 @@
import base64
import json
import random
import string
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
@@ -15,7 +15,9 @@ class ResponseHandler(ABC):
"""响应处理器基类"""
@abstractmethod
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False
) -> Dict[str, Any]:
pass
@@ -26,14 +28,20 @@ class GeminiResponseHandler(ResponseHandler):
self.thinking_first = True
self.thinking_status = False
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False
) -> Dict[str, Any]:
if stream:
return _handle_gemini_stream_response(response, model, stream)
return _handle_gemini_normal_response(response, model, stream)
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
def _handle_openai_stream_response(
response: Dict[str, Any], model: str, finish_reason: str
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=True, gemini_format=False
)
if not text and not tool_calls:
delta = {}
else:
@@ -50,8 +58,12 @@ def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_
}
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
def _handle_openai_normal_response(
response: Dict[str, Any], model: str, finish_reason: str
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=False, gemini_format=False
)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
@@ -60,7 +72,11 @@ def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
"message": {
"role": "assistant",
"content": text,
"tool_calls": tool_calls,
},
"finish_reason": finish_reason,
}
],
@@ -77,59 +93,67 @@ class OpenAIResponseHandler(ResponseHandler):
self.thinking_status = False
def handle_response(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None,
) -> Optional[Dict[str, Any]]:
if stream:
return _handle_openai_stream_response(response, model, finish_reason)
return _handle_openai_normal_response(response, model, finish_reason)
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
def handle_image_chat_response(
self, image_str: str, model: str, stream=False, finish_reason="stop"
):
if stream:
return _handle_openai_stream_image_response(image_str,model,finish_reason)
return _handle_openai_normal_image_response(image_str,model,finish_reason)
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
return _handle_openai_stream_image_response(image_str, model, finish_reason)
return _handle_openai_normal_image_response(image_str, model, finish_reason)
def _handle_openai_stream_image_response(
image_str: str, model: str, finish_reason: str
) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": image_str} if image_str else {},
"finish_reason": finish_reason
}]
"choices": [
{
"index": 0,
"delta": {"content": image_str} if image_str else {},
"finish_reason": finish_reason,
}
],
}
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
def _handle_openai_normal_image_response(
image_str: str, model: str, finish_reason: str
) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": image_str
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": image_str},
"finish_reason": finish_reason,
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
def _extract_result(
response: Dict[str, Any],
model: str,
stream: bool = False,
gemini_format: bool = False,
) -> tuple[str, List[Dict[str, Any]]]:
text, tool_calls = "", []
if stream:
if response.get("candidates"):
@@ -145,14 +169,10 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
text = _format_execution_result(parts[0]["executableCodeResult"])
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
elif "inlineData" in parts[0]:
text = _format_execution_result(parts[0]["codeExecutionResult"])
elif "inline_data" in parts[0]:
text = _extract_image_data(parts[0])
else:
text = ""
@@ -165,10 +185,10 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
@@ -183,37 +203,50 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
elif "inline_data" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
tool_calls = _extract_tool_calls(
candidate["content"]["parts"], gemini_format
)
else:
text = "暂无返回"
return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
base64_data = part["inline_data"]["data"]
# 将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
upload_response = image_uploader.upload(bytes_data, filename)
if upload_response.success:
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
def _extract_tool_calls(
parts: List[Dict[str, Any]], gemini_format: bool
) -> List[Dict[str, Any]]:
"""提取工具调用信息"""
if not parts or not isinstance(parts, list):
return []
@@ -249,8 +282,12 @@ def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> Lis
return tool_calls
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
def _handle_gemini_stream_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=stream, gemini_format=True
)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
@@ -259,8 +296,12 @@ def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream:
return response
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
def _handle_gemini_normal_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls = _extract_result(
response, model, stream=stream, gemini_format=True
)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
@@ -278,10 +319,10 @@ def _format_code_block(code_data: dict) -> str:
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
if (
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
):
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
text += "\n\n---\n\n"

View File

@@ -206,3 +206,6 @@ def get_update_logger():
def get_scheduler_routes():
    """Return the logger for the scheduler routes module.

    NOTE(review): unlike sibling helpers (e.g. ``get_update_logger``) the
    name omits a ``_logger`` suffix even though it returns a logger —
    consider renaming for consistency (would require updating callers).
    """
    return Logger.setup_logger("scheduler_routes")
def get_message_converter_logger():
    """Return the dedicated logger for the message converter module."""
    logger_name = "message_converter"
    return Logger.setup_logger(logger_name)

View File

@@ -1,13 +1,17 @@
# app/services/chat_service.py
import datetime
import json
import re
import datetime # Add datetime import
import time # Add time import
import time
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.database.services import (
add_error_log,
add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
@@ -16,21 +20,17 @@ from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
from app.database.services import add_error_log, add_request_log # Import add_request_log
logger = get_openai_logger()
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片、音频或视频部分 (inlineData)"""
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
for content in contents:
if content and "parts" in content and isinstance(content["parts"], list):
for part in content["parts"]:
# Check if the part is a dictionary and contains 'inlineData'
if isinstance(part, dict) and "inlineData" in part:
# Optionally, could check part["inlineData"].get("mimeType") prefix
if isinstance(part, dict) and "inline_data" in part:
return True
# Add checks here if Gemini uses other keys for media (e.g., 'fileData')
return False
@@ -49,12 +49,12 @@ def _build_tools(
or model.endswith("-image")
or model.endswith("-image-generation")
)
and not _has_media_parts(messages) # Use the updated check
and not _has_media_parts(messages) # Use the updated check
):
tool["codeExecution"] = {}
logger.debug("Code execution tool enabled.")
elif _has_media_parts(messages):
logger.debug("Code execution tool disabled due to media parts presence.")
logger.debug("Code execution tool disabled due to media parts presence.")
if model.endswith("-search"):
tool["googleSearch"] = {}
@@ -69,7 +69,9 @@ def _build_tools(
if item.get("type", "") == "function" and item.get("function"):
function = deepcopy(item.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
if parameters.get("type") == "object" and not parameters.get(
"properties", {}
):
function.pop("parameters", None)
function_declarations.append(function)
@@ -138,9 +140,11 @@ def _build_payload(
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if request.model.endswith("-non-thinking"):
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
if request.model in settings.THINKING_BUDGET_MAP:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model,1000)}
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
}
if (
instruction
@@ -212,7 +216,7 @@ class OpenAIChatService:
try:
response = await self.api_client.generate_content(payload, model, api_key)
is_success = True
status_code = 200 # Assume 200 on success
status_code = 200
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)
@@ -225,17 +229,17 @@ class OpenAIChatService:
if match:
status_code = int(match.group(1))
else:
status_code = 500 # Default if parsing fails
status_code = 500
await add_error_log(
gemini_key=api_key, # Note: Parameter name is gemini_key in add_error_log
gemini_key=api_key,
model_name=model,
error_type="openai-chat-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload
request_msg=payload,
)
raise e # Re-throw exception
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
@@ -245,7 +249,7 @@ class OpenAIChatService:
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)
async def _handle_stream_completion(
@@ -300,7 +304,7 @@ class OpenAIChatService:
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
is_success = True
status_code = 200 # Assume 200 on success
status_code = 200
break # 成功后退出循环
except Exception as e:
retries += 1
@@ -314,7 +318,7 @@ class OpenAIChatService:
if match:
status_code = int(match.group(1))
else:
status_code = 500 # Default if parsing fails
status_code = 500
# Log error to error log table
await add_error_log(
@@ -323,38 +327,40 @@ class OpenAIChatService:
error_type="openai-chat-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload
request_msg=payload,
)
# Attempt to switch API Key
# Ensure key_manager is available (might need adjustment if not always passed)
if self.key_manager:
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
api_key = await self.key_manager.handle_api_failure(
current_attempt_key, retries
)
if api_key:
logger.info(f"Switched to new API key: {api_key}")
else:
logger.error(f"No valid API key available after {retries} retries.")
break # Exit loop if no key available
logger.error(
f"No valid API key available after {retries} retries."
)
break
else:
logger.error("KeyManager not available for retry logic.")
break # Exit loop if key manager is missing
logger.error("KeyManager not available for retry logic.")
break
if retries >= max_retries:
logger.error(
f"Max retries ({max_retries}) reached for streaming."
)
break # Exit loop after max retries
logger.error(f"Max retries ({max_retries}) reached for streaming.")
break
finally:
# Log the final outcome of the streaming request
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=final_api_key, # Log the last key used
is_success=is_success, # Log the final success status
status_code=status_code, # Log the last known status code
latency_ms=latency_ms, # Log total time including retries
request_time=request_datetime
api_key=final_api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime,
)
# If the loop finished due to failure, yield error and DONE
if not is_success and retries >= max_retries:
@@ -362,9 +368,7 @@ class OpenAIChatService:
yield "data: [DONE]\n\n"
async def create_image_chat_completion(
self,
request: ChatRequest,
api_key: str
self, request: ChatRequest, api_key: str
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
@@ -374,18 +378,22 @@ class OpenAIChatService:
)
if request.stream:
return self._handle_stream_image_completion(request.model, image_res, api_key)
return self._handle_stream_image_completion(
request.model, image_res, api_key
)
else:
return await self._handle_normal_image_completion(request.model, image_res, api_key)
return await self._handle_normal_image_completion(
request.model, image_res, api_key
)
async def _handle_stream_image_completion(
self, model: str, image_data: str, api_key:str
self, model: str, image_data: str, api_key: str
) -> AsyncGenerator[str, None]:
logger.info(f"Starting stream image completion for model: {model}")
start_time = time.perf_counter()
request_datetime = datetime.datetime.now() # Although not used for DB log here
request_datetime = datetime.datetime.now()
is_success = False
status_code = None # Although not used for DB log here
status_code = None
try:
if image_data:
@@ -409,7 +417,9 @@ class OpenAIChatService:
# 如果没有文本内容如图片URL等整块输出
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
logger.info(f"Stream image completion finished successfully for model: {model}")
logger.info(
f"Stream image completion finished successfully for model: {model}"
)
is_success = True
status_code = 200
yield "data: [DONE]\n\n"
@@ -417,46 +427,51 @@ class OpenAIChatService:
is_success = False
error_log_msg = f"Stream image completion failed for model {model}: {e}"
logger.error(error_log_msg)
status_code = 500 # Default error code
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-stream", # Specific error type
error_type="openai-image-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
request_msg={
"image_data_truncated": image_data[:1000]
},
)
yield f"data: {json.dumps({'error': error_log_msg})}\n\n" # Send error to client
yield "data: [DONE]\n\n" # Still need DONE message
# Re-raising might break the stream, decide if needed
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
yield "data: [DONE]\n\n"
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
logger.info(f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}")
logger.info(
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)
async def _handle_normal_image_completion(
self, model: str, image_data: str, api_key: str # Add api_key parameter
self, model: str, image_data: str, api_key: str
) -> Dict[str, Any]:
logger.info(f"Starting normal image completion for model: {model}")
start_time = time.perf_counter()
request_datetime = datetime.datetime.now() # Although not used for DB log here
request_datetime = datetime.datetime.now()
is_success = False
status_code = None # Although not used for DB log here
status_code = None
result = None
try:
result = self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)
logger.info(f"Normal image completion finished successfully for model: {model}")
logger.info(
f"Normal image completion finished successfully for model: {model}"
)
is_success = True
status_code = 200
return result
@@ -464,26 +479,30 @@ class OpenAIChatService:
is_success = False
error_log_msg = f"Normal image completion failed for model {model}: {e}"
logger.error(error_log_msg)
status_code = 500 # Default error code
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="openai-image-non-stream", # Specific error type
error_type="openai-image-non-stream",
error_log=error_log_msg,
error_code=status_code,
request_msg={"image_data_truncated": image_data[:1000]} # Log truncated data
request_msg={
"image_data_truncated": image_data[:1000]
},
)
# Re-raise the exception so the caller knows about the failure
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
logger.info(f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}")
logger.info(
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
request_time=request_datetime,
)