Compare commits

...

11 Commits

Author SHA1 Message Date
snaily
8779a5f0b3 feat: 添加对 image-generation 模型的支持
在 gemini_chat_service 和 openai_chat_service 中添加对 "-image-generation" 后缀模型的支持
确保 image-generation 模型与 image 模型有相同的处理逻辑
2025-03-16 23:53:53 +08:00
cr-zhichen
89f2825ac7 feat: 新增对CloudFlare ImgBed的支持,更新环境变量和文档 2025-03-16 04:39:40 +00:00
snaily
985a12554d fix:修改OpenAI消息转换器中assistant消息处理逻辑,将特殊处理的目标从最后一条消息调整为倒数第二条消息。 2025-03-15 21:18:20 +08:00
snaily
37a7a140fc feat:改进消息转换器中的图像处理和消息分割逻辑
添加 _get_mime_type_and_data 函数从 base64 字符串中提取 MIME 类型和数据
修改 _convert_image 函数使用动态检测的 MIME 类型,而非硬编码
将 _process_text_with_image 中的 MIME 类型从 image/jpeg 改为 image/png
简化异常处理逻辑
优化 OpenAIMessageConverter 中的消息分割逻辑,仅对最后一个 assistant 消息进行分割处理
2025-03-15 21:11:10 +08:00
zhanghaoyu
28e67cc3fa 1. modify IMAGE_URL_PATTERN
2. modify import
2025-03-15 12:37:56 +08:00
zhanghaoyu7
d99a0bde93 feat: 新增图文上下文同步 2025-03-14 16:29:03 +08:00
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
snaily
71af1db330 feat: 添加Gemini图像生成与处理功能
主要更新:

添加图像模型支持

新增MODEL_IMAGE配置项
在模型列表中添加gemini-2.0-flash-exp-image模型
修改ModelService以支持图像模型
增强图像处理能力

添加PicGoUploader类用于图像上传
实现图像响应处理逻辑(_extract_image_data)
支持base64图像数据的解码与上传
优化请求与响应处理

为图像模型添加特殊处理逻辑
修改API客户端以支持图像模型
更新GeminiRequest默认值
安全性调整

将TOOLS_CODE_EXECUTION_ENABLED默认设置为false
2025-03-14 00:27:23 +08:00
15 changed files with 521 additions and 106 deletions

View File

@@ -2,7 +2,8 @@ API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true
MODEL_IMAGE=["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta
@@ -12,6 +13,9 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016

View File

@@ -10,7 +10,7 @@ COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=true
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
# Expose port

View File

@@ -74,8 +74,11 @@
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0
# 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms、picgo、cloudflare_imgbed
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key可在项目后台设置若无鉴权则可直接置空。
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
@@ -138,10 +141,26 @@
- `UPLOAD_PROVIDER`: 图片上传服务提供商
- 默认值: `smms`
- 说明: 目前支持 SM.MS 图床
- 可选值: `smms`, `picgo`, `cloudflare_imgbed`
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
- `SMMS_SECRET_TOKEN`: SM.MS API Token
- 用途: 用于图片上传到 SM.MS 图床
- 获取方式: 需要在 SM.MS 官网注册并获取
- 用途: 用于图片上传到 SM.MS 图床的身份验证。
- 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取
- `PICGO_API_KEY`: PicGo API Key
- 用途: 用于图片上传到 PicGo 图床的身份验证。
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权则留空即可。
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
#### 流式输出优化配置

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.core.security import SecurityService
@@ -23,7 +23,7 @@ async def get_key_manager():
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key()
model_service = ModelService(settings.MODEL_SEARCH)
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
@router.get("/models")
@@ -36,12 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental",
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2})
# 模型名称以及对应的详细信息
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# 添加搜索模型
if settings.MODEL_SEARCH:
for name in settings.MODEL_SEARCH:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
@@ -62,6 +90,9 @@ async def generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
response = await chat_service.generate_content(
model=model_name,
@@ -92,6 +123,9 @@ async def stream_generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
response_stream = chat_service.stream_generate_content(
model=model_name,

View File

@@ -17,7 +17,7 @@ logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.MODEL_SEARCH)
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
@@ -61,6 +61,10 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
try:
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":

View File

@@ -7,6 +7,7 @@ class Settings(BaseSettings):
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
@@ -16,6 +17,9 @@ class Settings(BaseSettings):
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置

View File

@@ -149,6 +149,229 @@ class QiniuUploader(ImageUploader):
pass
class PicGoUploader(ImageUploader):
"""Chevereto API 图片上传器"""
def __init__(self, api_key: str, api_url: str = "https://www.picgo.net/api/1/upload"):
"""
初始化 Chevereto 上传器
Args:
api_key: Chevereto API 密钥
api_url: Chevereto API 上传地址
"""
self.api_key = api_key
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到 Chevereto 服务
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求头
headers = {
"X-API-Key": self.api_key
}
# 准备文件数据
files = {
"source": (filename, file)
}
# 发送请求
response = requests.post(
self.api_url,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if result.get("status_code") != 200:
error_message = "Upload failed"
if "error" in result:
error_message = result["error"].get("message", error_message)
raise UploadError(
message=error_message,
error_type=UploadErrorType.SERVER_ERROR,
status_code=result.get("status_code"),
details=result.get("error")
)
# 从响应中提取图片信息
image_data = result.get("image", {})
# 构建图片元数据
image_metadata = ImageMetadata(
width=image_data.get("width", 0),
height=image_data.get("height", 0),
filename=image_data.get("filename", filename),
size=image_data.get("size", 0),
url=image_data.get("url", ""),
delete_url=image_data.get("delete_url", None)
)
return UploadResponse(
success=True,
code="success",
message=result.get("success", {}).get("message", "Upload success"),
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class CloudFlareImgBedUploader(ImageUploader):
"""CloudFlare图床上传器"""
def __init__(self, auth_code: str, api_url: str):
"""
初始化CloudFlare图床上传器
Args:
auth_code: 认证码
api_url: 上传API地址
"""
self.auth_code = auth_code
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到CloudFlare图床
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求URL添加认证码参数如果存在
if self.auth_code:
request_url = f"{self.api_url}?authCode={self.auth_code}"
else:
request_url = self.api_url
# 准备文件数据
files = {
"file": (filename, file)
}
# 发送请求
response = requests.post(
request_url,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证响应格式
if not result or not isinstance(result, list) or len(result) == 0:
raise UploadError(
message="Invalid response format",
error_type=UploadErrorType.PARSE_ERROR
)
# 获取文件URL
file_path = result[0].get("src")
if not file_path:
raise UploadError(
message="Missing file URL in response",
error_type=UploadErrorType.PARSE_ERROR
)
# 构建完整URL如果返回的是相对路径
base_url = self.api_url.split("/upload")[0]
full_url = file_path if file_path.startswith(("http://", "https://")) else f"{base_url}{file_path}"
# 构建图片元数据注意CloudFlare-ImgBed不返回所有元数据所以部分字段为默认值
image_metadata = ImageMetadata(
width=0, # CloudFlare-ImgBed不返回宽度
height=0, # CloudFlare-ImgBed不返回高度
filename=filename,
size=0, # CloudFlare-ImgBed不返回大小
url=full_url,
delete_url=None # CloudFlare-ImgBed不返回删除URL
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError, IndexError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class ImageUploaderFactory:
@staticmethod
def create(provider: str, **credentials) -> ImageUploader:
@@ -159,5 +382,12 @@ class ImageUploaderFactory:
credentials["access_key"],
credentials["secret_key"]
)
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url)
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
credentials["auth_code"],
credentials["base_url"]
)
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -33,8 +33,8 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel):
contents: List[GeminiContent]
contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None
generationConfig: Optional[GenerationConfig] = {}
systemInstruction: Optional[SystemInstruction] = None

View File

@@ -24,10 +24,18 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url
self.timeout = timeout
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
def _get_real_model(self, model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
@@ -38,8 +46,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -1,9 +1,13 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
import re
from typing import Any, Dict, List, Optional
import requests
import base64
SUPPORTED_ROLES = ["user", "model", "system"]
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
class MessageConverter(ABC):
@@ -13,13 +17,36 @@ class MessageConverter(ABC):
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
def _get_mime_type_and_data(base64_string):
"""
从 base64 字符串中提取 MIME 类型和数据。
参数:
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
返回:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = r'data:([^;]+);base64,(.+)'
match = re.match(pattern, base64_string)
if match:
mime_type = match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
"mime_type": mime_type,
"data": encoded_data
}
}
return {
@@ -29,12 +56,62 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
}
def _convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
"""
处理可能包含图片URL的文本提取图片并转换为base64
Args:
text: 可能包含图片URL的文本
Returns:
List[Dict[str, Any]]: 包含文本和图片的部分列表
"""
parts = []
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(1)
# 将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append({
"inlineData": {
"mimeType": "image/png",
"data": base64_data
}
})
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
else:
# 没有图片URL作为纯文本处理
parts.append({"text": text})
return parts
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction = None
system_instruction_parts = []
for idx, msg in enumerate(messages):
role = msg.get("role", "")
@@ -49,9 +126,18 @@ class OpenAIMessageConverter(MessageConverter):
role = "model"
parts = []
if isinstance(msg["content"], str) and msg["content"]:
# 特别处理最后一个assistant的消息按\n\n分割
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.append({"text": msg["content"]})
parts.extend(_process_text_with_image(msg["content"]))
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
@@ -64,8 +150,16 @@ class OpenAIMessageConverter(MessageConverter):
if parts:
if role == "system":
system_instruction = {"role": "system", "parts": parts}
system_instruction_parts.extend(parts)
else:
converted_messages.append({"role": role, "parts": parts})
return converted_messages, system_instruction
system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction

View File

@@ -1,5 +1,6 @@
# app/services/chat/response_handler.py
import base64
import json
import random
import string
@@ -8,6 +9,7 @@ from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
class ResponseHandler(ABC):
@@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if not parts:
return "", []
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
@@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else:
text = ""
text = _add_search_link_text(model, candidate, text)
@@ -235,14 +180,40 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = candidate["content"]["parts"][0]["text"]
else:
text = ""
for part in candidate["content"]["parts"]:
text += part.get("text", "")
if "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else:
text = "暂无返回"
return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success:
text = f"![image]({upload_response.data.url})"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息"""
if not parts or not isinstance(parts, list):

View File

@@ -62,14 +62,19 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": _build_tools(model, payload),
request_dict = request.model_dump()
payload = {
"contents": request_dict.get("contents", []),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
"generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": request_dict.get("systemInstruction", "")
}
if model.endswith("-image") or model.endswith("-image-generation"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
return payload
class GeminiChatService:

View File

@@ -96,11 +96,6 @@ class ImageCreateService:
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
upload_response = image_uploader.upload(image_data,filename)
if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode('utf-8')
@@ -109,6 +104,30 @@ class ImageCreateService:
"revised_prompt": request.prompt
})
else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
)
else:
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
upload_response = image_uploader.upload(image_data, filename)
images_data.append({
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt

View File

@@ -7,8 +7,9 @@ from app.core.config import settings
logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list):
def __init__(self, model_search: list, model_image: list):
self.model_search = model_search
self.model_image = model_image
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
@@ -57,9 +58,27 @@ class ModelService:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
if model_id in self.model_image:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model)
return openai_format
def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in settings.MODEL_SEARCH
if model.endswith("-image"):
model = model[:-6]
return model in settings.MODEL_IMAGE
return True

View File

@@ -35,7 +35,7 @@ def _build_tools(
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
@@ -110,12 +110,16 @@ def _build_payload(
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
if (
instruction
and isinstance(instruction, dict)
and instruction.get("role") == "system"
and instruction.get("parts")
and not request.model.endswith("-image")
and not request.model.endswith("-image-generation")
):
payload["systemInstruction"] = instruction