Compare commits

..

7 Commits

Author SHA1 Message Date
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
snaily
71af1db330 feat: 添加Gemini图像生成与处理功能
主要更新:

添加图像模型支持

新增MODEL_IMAGE配置项
在模型列表中添加gemini-2.0-flash-exp-image模型
修改ModelService以支持图像模型
增强图像处理能力

添加PicGoUploader类用于图像上传
实现图像响应处理逻辑(_extract_image_data)
支持base64图像数据的解码与上传
优化请求与响应处理

为图像模型添加特殊处理逻辑
修改API客户端以支持图像模型
更新GeminiRequest默认值
安全性调整

将TOOLS_CODE_EXECUTION_ENABLED默认设置为false
2025-03-14 00:27:23 +08:00
snaily
fb523f4a2e feat: 将 StreamOptimizer 参数改为可配置
将 StreamOptimizer 中的硬编码参数改为通过配置文件可配置的参数,提高了系统的灵活性。具体修改包括:

在 .env.example 中添加 stream_optimizer 相关配置参数
在 app/core/config.py 中添加对应的配置项
修改 app/services/chat/stream_optimizer.py 从配置中读取参数
在 README.md 中添加流式输出优化配置的详细说明
2025-03-06 16:56:01 +08:00
snaily
40e5ffa5f4 chore: Adjust StreamOptimizer parameters for improved performance
- Reduced long_text_threshold from 100 to 50 characters
- Decreased chunk_size from 10 to 5

These changes aim to optimize the streaming output for better user experience
and responsiveness, particularly for medium-length texts.
2025-03-06 16:45:35 +08:00
15 changed files with 317 additions and 96 deletions

View File

@@ -2,7 +2,8 @@ API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true
MODEL_IMAGE=["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta
@@ -12,4 +13,12 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
##########################################################################

View File

@@ -10,7 +10,7 @@ COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=true
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
# Expose port

View File

@@ -76,6 +76,13 @@
# 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
```
### 配置说明
@@ -136,6 +143,24 @@
- 用途: 用于图片上传到 SM.MS 图床
- 获取方式: 需要在 SM.MS 官网注册并获取
#### 流式输出优化配置
- `STREAM_MIN_DELAY`: 最小延迟时间
- 默认值: `0.016`(秒)
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
- `STREAM_MAX_DELAY`: 最大延迟时间
- 默认值: `0.024`(秒)
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
- 默认值: `10`(字符)
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
- 默认值: `50`(字符)
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
- `STREAM_CHUNK_SIZE`: 长文本分块大小
- 默认值: `5`(字符)
- 说明: 长文本分块输出时,每个块的大小
### ▶️ 运行
#### 使用 Docker (推荐)

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.core.security import SecurityService
@@ -23,7 +23,7 @@ async def get_key_manager():
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key()
model_service = ModelService(settings.MODEL_SEARCH)
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
@router.get("/models")
@@ -36,12 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental",
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2})
# 模型名称以及对应的详细信息
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# 添加搜索模型
if settings.MODEL_SEARCH:
for name in settings.MODEL_SEARCH:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
@@ -62,6 +90,9 @@ async def generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
response = await chat_service.generate_content(
model=model_name,
@@ -92,6 +123,9 @@ async def stream_generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
response_stream = chat_service.stream_generate_content(
model=model_name,

View File

@@ -17,7 +17,7 @@ logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.MODEL_SEARCH)
model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
@@ -61,6 +61,10 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
try:
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":

View File

@@ -7,6 +7,7 @@ class Settings(BaseSettings):
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
@@ -16,7 +17,15 @@ class Settings(BaseSettings):
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self):
super().__init__()

View File

@@ -149,6 +149,116 @@ class QiniuUploader(ImageUploader):
pass
class PicGoUploader(ImageUploader):
"""Chevereto API 图片上传器"""
def __init__(self, api_key: str, api_url: str = "https://www.picgo.net/api/1/upload"):
"""
初始化 Chevereto 上传器
Args:
api_key: Chevereto API 密钥
api_url: Chevereto API 上传地址
"""
self.api_key = api_key
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到 Chevereto 服务
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求头
headers = {
"X-API-Key": self.api_key
}
# 准备文件数据
files = {
"source": (filename, file)
}
# 发送请求
response = requests.post(
self.api_url,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if result.get("status_code") != 200:
error_message = "Upload failed"
if "error" in result:
error_message = result["error"].get("message", error_message)
raise UploadError(
message=error_message,
error_type=UploadErrorType.SERVER_ERROR,
status_code=result.get("status_code"),
details=result.get("error")
)
# 从响应中提取图片信息
image_data = result.get("image", {})
# 构建图片元数据
image_metadata = ImageMetadata(
width=image_data.get("width", 0),
height=image_data.get("height", 0),
filename=image_data.get("filename", filename),
size=image_data.get("size", 0),
url=image_data.get("url", ""),
delete_url=image_data.get("delete_url", None)
)
return UploadResponse(
success=True,
code="success",
message=result.get("success", {}).get("message", "Upload success"),
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class ImageUploaderFactory:
@staticmethod
def create(provider: str, **credentials) -> ImageUploader:
@@ -159,5 +269,7 @@ class ImageUploaderFactory:
credentials["access_key"],
credentials["secret_key"]
)
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url)
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -33,8 +33,8 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel):
contents: List[GeminiContent]
contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None
generationConfig: Optional[GenerationConfig] = {}
systemInstruction: Optional[SystemInstruction] = None

View File

@@ -24,10 +24,18 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url
self.timeout = timeout
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
def _get_real_model(self, model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
@@ -38,8 +46,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -34,7 +34,7 @@ class OpenAIMessageConverter(MessageConverter):
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction = None
system_instruction_parts = []
for idx, msg in enumerate(messages):
role = msg.get("role", "")
@@ -64,8 +64,16 @@ class OpenAIMessageConverter(MessageConverter):
if parts:
if role == "system":
system_instruction = {"role": "system", "parts": parts}
system_instruction_parts.extend(parts)
else:
converted_messages.append({"role": role, "parts": parts})
system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction

View File

@@ -1,5 +1,6 @@
# app/services/chat/response_handler.py
import base64
import json
import random
import string
@@ -8,6 +9,7 @@ from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
class ResponseHandler(ABC):
@@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if not parts:
return "", []
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
@@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else:
text = ""
text = _add_search_link_text(model, candidate, text)
@@ -235,14 +180,38 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = candidate["content"]["parts"][0]["text"]
else:
text = ""
for part in candidate["content"]["parts"]:
text += part.get("text", "")
if "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else:
text = "暂无返回"
return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success:
text = f"\n![image]({upload_response.data.url})\n"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息"""
if not parts or not isinstance(parts, list):

View File

@@ -4,6 +4,7 @@ import asyncio
import math
from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
@@ -20,8 +21,8 @@ class StreamOptimizer:
min_delay: float = 0.016,
max_delay: float = 0.024,
short_text_threshold: int = 10,
long_text_threshold: int = 100,
chunk_size: int = 10):
long_text_threshold: int = 50,
chunk_size: int = 5):
"""初始化流式输出优化器
参数:
@@ -112,5 +113,20 @@ class StreamOptimizer:
# 创建默认的优化器实例,可以直接导入使用
openai_optimizer = StreamOptimizer(logger=logger_openai)
gemini_optimizer = StreamOptimizer(logger=logger_gemini)
openai_optimizer = StreamOptimizer(
logger=logger_openai,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
)
gemini_optimizer = StreamOptimizer(
logger=logger_gemini,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
)

View File

@@ -62,14 +62,19 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": _build_tools(model, payload),
request_dict = request.model_dump()
payload = {
"contents": request_dict.get("contents", []),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
"generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": request_dict.get("systemInstruction", "")
}
if model.endswith("-image"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
return payload
class GeminiChatService:

View File

@@ -7,8 +7,9 @@ from app.core.config import settings
logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list):
def __init__(self, model_search: list, model_image: list):
self.model_search = model_search
self.model_image = model_image
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
@@ -57,9 +58,27 @@ class ModelService:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
if model_id in self.model_image:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model)
return openai_format
def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in settings.MODEL_SEARCH
if model.endswith("-image"):
model = model[:-6]
return model in settings.MODEL_IMAGE
return True

View File

@@ -35,7 +35,7 @@ def _build_tools(
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image"))
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
@@ -110,12 +110,15 @@ def _build_payload(
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
if request.model.endswith("-image"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
if (
instruction
and isinstance(instruction, dict)
and instruction.get("role") == "system"
and instruction.get("parts")
and not request.model.endswith("-image")
):
payload["systemInstruction"] = instruction