Compare commits

..

7 Commits

Author SHA1 Message Date
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
snaily
71af1db330 feat: 添加Gemini图像生成与处理功能
主要更新:

添加图像模型支持

新增MODEL_IMAGE配置项
在模型列表中添加gemini-2.0-flash-exp-image模型
修改ModelService以支持图像模型
增强图像处理能力

添加PicGoUploader类用于图像上传
实现图像响应处理逻辑(_extract_image_data)
支持base64图像数据的解码与上传
优化请求与响应处理

为图像模型添加特殊处理逻辑
修改API客户端以支持图像模型
更新GeminiRequest默认值
安全性调整

将TOOLS_CODE_EXECUTION_ENABLED默认设置为false
2025-03-14 00:27:23 +08:00
snaily
fb523f4a2e feat: 将 StreamOptimizer 参数改为可配置
将 StreamOptimizer 中的硬编码参数改为通过配置文件可配置的参数,提高了系统的灵活性。具体修改包括:

在 .env.example 中添加 stream_optimizer 相关配置参数
在 app/core/config.py 中添加对应的配置项
修改 app/services/chat/stream_optimizer.py 从配置中读取参数
在 README.md 中添加流式输出优化配置的详细说明
2025-03-06 16:56:01 +08:00
snaily
40e5ffa5f4 chore: Adjust StreamOptimizer parameters for improved performance
- Reduced long_text_threshold from 100 to 50 characters
- Decreased chunk_size from 10 to 5

These changes aim to optimize the streaming output for better user experience
and responsiveness, particularly for medium-length texts.
2025-03-06 16:45:35 +08:00
15 changed files with 317 additions and 96 deletions

View File

@@ -2,7 +2,8 @@ API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"] ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456 # AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"] MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true MODEL_IMAGE=["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta BASE_URL=https://generativelanguage.googleapis.com/v1beta
@@ -12,4 +13,12 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002 CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
########################################################################## ##########################################################################

View File

@@ -10,7 +10,7 @@ COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]' ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]' ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=true ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]' ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
# Expose port # Expose port

View File

@@ -76,6 +76,13 @@
# 图片上传配置 # 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
``` ```
### 配置说明 ### 配置说明
@@ -136,6 +143,24 @@
- 用途: 用于图片上传到 SM.MS 图床 - 用途: 用于图片上传到 SM.MS 图床
- 获取方式: 需要在 SM.MS 官网注册并获取 - 获取方式: 需要在 SM.MS 官网注册并获取
#### 流式输出优化配置
- `STREAM_MIN_DELAY`: 最小延迟时间
- 默认值: `0.016`(秒)
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
- `STREAM_MAX_DELAY`: 最大延迟时间
- 默认值: `0.024`(秒)
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
- 默认值: `10`(字符)
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
- 默认值: `50`(字符)
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
- `STREAM_CHUNK_SIZE`: 长文本分块大小
- 默认值: `5`(字符)
- 说明: 长文本分块输出时,每个块的大小
### ▶️ 运行 ### ▶️ 运行
#### 使用 Docker (推荐) #### 使用 Docker (推荐)

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings from app.core.config import settings
from app.core.logger import get_gemini_logger from app.core.logger import get_gemini_logger
from app.core.security import SecurityService from app.core.security import SecurityService
@@ -23,7 +23,7 @@ async def get_key_manager():
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)): async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key() return await key_manager.get_next_working_key()
model_service = ModelService(settings.MODEL_SEARCH) model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
@router.get("/models") @router.get("/models")
@@ -36,12 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
api_key = await key_manager.get_next_working_key() api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key) models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental", # 模型名称以及对应的详细信息
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, # 添加搜索模型
"topP": 0.95, "topK": 64, "maxTemperature": 2}) if settings.MODEL_SEARCH:
for name in settings.MODEL_SEARCH:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json return models_json
@@ -62,6 +90,9 @@ async def generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response = await chat_service.generate_content( response = await chat_service.generate_content(
model=model_name, model=model_name,
@@ -92,6 +123,9 @@ async def stream_generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response_stream = chat_service.stream_generate_content( response_stream = chat_service.stream_generate_content(
model=model_name, model=model_name,

View File

@@ -17,7 +17,7 @@ logger = get_openai_logger()
# 初始化服务 # 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.MODEL_SEARCH) model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
embedding_service = EmbeddingService(settings.BASE_URL) embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService() image_create_service = ImageCreateService()
@@ -61,6 +61,10 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
try: try:
# 如果model是imagen3,使用paid_key # 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":

View File

@@ -7,6 +7,7 @@ class Settings(BaseSettings):
ALLOWED_TOKENS: List[str] ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True SHOW_THINKING_PROCESS: bool = True
@@ -16,7 +17,15 @@ class Settings(BaseSettings):
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002" CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms" UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = "" SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
TEST_MODEL: str = "gemini-1.5-flash" TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -149,6 +149,116 @@ class QiniuUploader(ImageUploader):
pass pass
class PicGoUploader(ImageUploader):
"""Chevereto API 图片上传器"""
def __init__(self, api_key: str, api_url: str = "https://www.picgo.net/api/1/upload"):
"""
初始化 Chevereto 上传器
Args:
api_key: Chevereto API 密钥
api_url: Chevereto API 上传地址
"""
self.api_key = api_key
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到 Chevereto 服务
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求头
headers = {
"X-API-Key": self.api_key
}
# 准备文件数据
files = {
"source": (filename, file)
}
# 发送请求
response = requests.post(
self.api_url,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if result.get("status_code") != 200:
error_message = "Upload failed"
if "error" in result:
error_message = result["error"].get("message", error_message)
raise UploadError(
message=error_message,
error_type=UploadErrorType.SERVER_ERROR,
status_code=result.get("status_code"),
details=result.get("error")
)
# 从响应中提取图片信息
image_data = result.get("image", {})
# 构建图片元数据
image_metadata = ImageMetadata(
width=image_data.get("width", 0),
height=image_data.get("height", 0),
filename=image_data.get("filename", filename),
size=image_data.get("size", 0),
url=image_data.get("url", ""),
delete_url=image_data.get("delete_url", None)
)
return UploadResponse(
success=True,
code="success",
message=result.get("success", {}).get("message", "Upload success"),
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class ImageUploaderFactory: class ImageUploaderFactory:
@staticmethod @staticmethod
def create(provider: str, **credentials) -> ImageUploader: def create(provider: str, **credentials) -> ImageUploader:
@@ -159,5 +269,7 @@ class ImageUploaderFactory:
credentials["access_key"], credentials["access_key"],
credentials["secret_key"] credentials["secret_key"]
) )
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url)
raise ValueError(f"Unknown provider: {provider}") raise ValueError(f"Unknown provider: {provider}")

View File

@@ -33,8 +33,8 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel): class GeminiRequest(BaseModel):
contents: List[GeminiContent] contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = [] tools: Optional[List[Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None generationConfig: Optional[GenerationConfig] = {}
systemInstruction: Optional[SystemInstruction] = None systemInstruction: Optional[SystemInstruction] = None

View File

@@ -24,10 +24,18 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: def _get_real_model(self, model: str) -> str:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): if model.endswith("-search"):
model = model[:-7] model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload) response = await client.post(url, json=payload)
@@ -38,8 +46,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout) timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): model = self._get_real_model(model)
model = model[:-7]
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response: async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -34,7 +34,7 @@ class OpenAIMessageConverter(MessageConverter):
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = [] converted_messages = []
system_instruction = None system_instruction_parts = []
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
role = msg.get("role", "") role = msg.get("role", "")
@@ -64,8 +64,16 @@ class OpenAIMessageConverter(MessageConverter):
if parts: if parts:
if role == "system": if role == "system":
system_instruction = {"role": "system", "parts": parts} system_instruction_parts.extend(parts)
else: else:
converted_messages.append({"role": role, "parts": parts}) converted_messages.append({"role": role, "parts": parts})
system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction return converted_messages, system_instruction

View File

@@ -1,5 +1,6 @@
# app/services/chat/response_handler.py # app/services/chat/response_handler.py
import base64
import json import json
import random import random
import string import string
@@ -8,6 +9,7 @@ from typing import Dict, Any, List, Optional
import time import time
import uuid import uuid
from app.core.config import settings from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
class ResponseHandler(ABC): class ResponseHandler(ABC):
@@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
candidate = response["candidates"][0] candidate = response["candidates"][0]
content = candidate.get("content", {}) content = candidate.get("content", {})
parts = content.get("parts", []) parts = content.get("parts", [])
# if "thinking" in model: if not parts:
# if settings.SHOW_THINKING_PROCESS: return "", []
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if "text" in parts[0]: if "text" in parts[0]:
text = parts[0].get("text") text = parts[0].get("text")
elif "executableCode" in parts[0]: elif "executableCode" in parts[0]:
@@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = _format_execution_result( text = _format_execution_result(
parts[0]["codeExecutionResult"] parts[0]["codeExecutionResult"]
) )
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else: else:
text = "" text = ""
text = _add_search_link_text(model, candidate, text) text = _add_search_link_text(model, candidate, text)
@@ -235,14 +180,38 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = candidate["content"]["parts"][0]["text"] text = candidate["content"]["parts"][0]["text"]
else: else:
text = "" text = ""
for part in candidate["content"]["parts"]: if "parts" in candidate["content"]:
text += part.get("text", "") for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text) text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format) tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else: else:
text = "暂无返回" text = "暂无返回"
return text, tool_calls return text, tool_calls
def _extract_image_data(part: dict) -> str:
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
#将base64_data转成bytes数组
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success:
text = f"\n![image]({upload_response.data.url})\n"
else:
text = ""
return text
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]: def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息""" """提取工具调用信息"""
if not parts or not isinstance(parts, list): if not parts or not isinstance(parts, list):

View File

@@ -4,6 +4,7 @@ import asyncio
import math import math
from typing import Any, List, AsyncGenerator, Callable from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
logger_openai = get_openai_logger() logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger() logger_gemini = get_gemini_logger()
@@ -20,8 +21,8 @@ class StreamOptimizer:
min_delay: float = 0.016, min_delay: float = 0.016,
max_delay: float = 0.024, max_delay: float = 0.024,
short_text_threshold: int = 10, short_text_threshold: int = 10,
long_text_threshold: int = 100, long_text_threshold: int = 50,
chunk_size: int = 10): chunk_size: int = 5):
"""初始化流式输出优化器 """初始化流式输出优化器
参数: 参数:
@@ -112,5 +113,20 @@ class StreamOptimizer:
# 创建默认的优化器实例,可以直接导入使用 # 创建默认的优化器实例,可以直接导入使用
openai_optimizer = StreamOptimizer(logger=logger_openai) openai_optimizer = StreamOptimizer(
gemini_optimizer = StreamOptimizer(logger=logger_gemini) logger=logger_openai,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
)
gemini_optimizer = StreamOptimizer(
logger=logger_gemini,
min_delay=settings.STREAM_MIN_DELAY,
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
)

View File

@@ -62,14 +62,19 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload""" """构建请求payload"""
payload = request.model_dump() request_dict = request.model_dump()
return { payload = {
"contents": payload.get("contents", []), "contents": request_dict.get("contents", []),
"tools": _build_tools(model, payload), "tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model), "safetySettings": _get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}), "generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", []) "systemInstruction": request_dict.get("systemInstruction", "")
} }
if model.endswith("-image"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
return payload
class GeminiChatService: class GeminiChatService:

View File

@@ -7,8 +7,9 @@ from app.core.config import settings
logger = get_model_logger() logger = get_model_logger()
class ModelService: class ModelService:
def __init__(self, model_search: list): def __init__(self, model_search: list, model_image: list):
self.model_search = model_search self.model_search = model_search
self.model_image = model_image
self.base_url = "https://generativelanguage.googleapis.com/v1beta" self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
@@ -57,9 +58,27 @@ class ModelService:
search_model = openai_model.copy() search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search" search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model) openai_format["data"].append(search_model)
if model_id in self.model_image:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
if settings.CREATE_IMAGE_MODEL: if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy() image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat" image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model) openai_format["data"].append(image_model)
return openai_format return openai_format
def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in settings.MODEL_SEARCH
if model.endswith("-image"):
model = model[:-6]
return model in settings.MODEL_IMAGE
return True

View File

@@ -35,7 +35,7 @@ def _build_tools(
if ( if (
settings.TOOLS_CODE_EXECUTION_ENABLED settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model) and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image"))
and not _has_image_parts(messages) and not _has_image_parts(messages)
): ):
tools.append({"code_execution": {}}) tools.append({"code_execution": {}})
@@ -110,12 +110,15 @@ def _build_payload(
"tools": _build_tools(request, messages), "tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model), "safetySettings": _get_safety_settings(request.model),
} }
if request.model.endswith("-image"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
if ( if (
instruction instruction
and isinstance(instruction, dict) and isinstance(instruction, dict)
and instruction.get("role") == "system" and instruction.get("role") == "system"
and instruction.get("parts") and instruction.get("parts")
and not request.model.endswith("-image")
): ):
payload["systemInstruction"] = instruction payload["systemInstruction"] = instruction