Compare commits

...

16 Commits

Author SHA1 Message Date
snaily
67c85c994a Merge pull request #14 from cr-zhichen/main
fix: 更新Cloudflare ImgBed上传请求URL,新增uploadNameType参数,以保持正确的目录结构命名。
2025-03-17 15:24:39 +08:00
cr-zhichen
ee979dd568 Merge branch 'main' of https://github.com/cr-zhichen/gemini-balance 2025-03-17 07:12:43 +00:00
cr-zhichen
e79a1ba56c feat: 更新CloudFlare ImgBed上传请求URL,新增uploadNameType参数,以保持正确的日期命名目录结构。 2025-03-17 07:10:21 +00:00
snaily
8779a5f0b3 feat: 添加对 image-generation 模型的支持
在 gemini_chat_service 和 openai_chat_service 中添加对 "-image-generation" 后缀模型的支持
确保 image-generation 模型与 image 模型有相同的处理逻辑
2025-03-16 23:53:53 +08:00
cr-zhichen
89f2825ac7 feat: 新增对CloudFlare ImgBed的支持,更新环境变量和文档 2025-03-16 04:39:40 +00:00
snaily
985a12554d fix:修改OpenAI消息转换器中assistant消息处理逻辑,将特殊处理的目标从最后一条消息调整为倒数第二条消息。 2025-03-15 21:18:20 +08:00
snaily
37a7a140fc feat:改进消息转换器中的图像处理和消息分割逻辑
添加 _get_mime_type_and_data 函数从 base64 字符串中提取 MIME 类型和数据
修改 _convert_image 函数使用动态检测的 MIME 类型,而非硬编码
将 _process_text_with_image 中的 MIME 类型从 image/jpeg 改为 image/png
简化异常处理逻辑
优化 OpenAIMessageConverter 中的消息分割逻辑,仅对最后一个 assistant 消息进行分割处理
2025-03-15 21:11:10 +08:00
zhanghaoyu
28e67cc3fa 1. modify IMAGE_URL_PATTERN
2. modify import
2025-03-15 12:37:56 +08:00
zhanghaoyu7
d99a0bde93 feat: 新增图文上下文同步 2025-03-14 16:29:03 +08:00
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
snaily
71af1db330 feat: 添加Gemini图像生成与处理功能
主要更新:

添加图像模型支持

新增MODEL_IMAGE配置项
在模型列表中添加gemini-2.0-flash-exp-image模型
修改ModelService以支持图像模型
增强图像处理能力

添加PicGoUploader类用于图像上传
实现图像响应处理逻辑(_extract_image_data)
支持base64图像数据的解码与上传
优化请求与响应处理

为图像模型添加特殊处理逻辑
修改API客户端以支持图像模型
更新GeminiRequest默认值
安全性调整

将TOOLS_CODE_EXECUTION_ENABLED默认设置为false
2025-03-14 00:27:23 +08:00
snaily
fb523f4a2e feat: 将 StreamOptimizer 参数改为可配置
将 StreamOptimizer 中的硬编码参数改为通过配置文件可配置的参数,提高了系统的灵活性。具体修改包括:

在 .env.example 中添加 stream_optimizer 相关配置参数
在 app/core/config.py 中添加对应的配置项
修改 app/services/chat/stream_optimizer.py 从配置中读取参数
在 README.md 中添加流式输出优化配置的详细说明
2025-03-06 16:56:01 +08:00
snaily
40e5ffa5f4 chore: Adjust StreamOptimizer parameters for improved performance
- Reduced long_text_threshold from 100 to 50 characters
- Decreased chunk_size from 10 to 5

These changes aim to optimize the streaming output for better user experience
and responsiveness, particularly for medium-length texts.
2025-03-06 16:45:35 +08:00
16 changed files with 580 additions and 110 deletions

View File

@@ -2,7 +2,8 @@ API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"] ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456 # AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"] MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true MODEL_IMAGE=["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta BASE_URL=https://generativelanguage.googleapis.com/v1beta
@@ -12,4 +13,14 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002 CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
##########################################################################
#########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
########################################################################## ##########################################################################

View File

@@ -10,7 +10,7 @@ COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]' ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]' ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=true ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]' ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
# Expose port # Expose port

View File

@@ -74,8 +74,18 @@
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0 CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0
# 图片上传配置 # 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms、picgo、cloudflare_imgbed
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
PICGO_API_KEY="your-picgo-apikey" # PicGo图床的API Key,可在 `https://www.picgo.net/settings/api` 获取
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgbed-auth-code" # CloudFlare图床的鉴权key,可在项目后台设置;若无鉴权则可直接置空。
# stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016
STREAM_MAX_DELAY=0.024
STREAM_SHORT_TEXT_THRESHOLD=10
STREAM_LONG_TEXT_THRESHOLD=50
STREAM_CHUNK_SIZE=5
``` ```
### 配置说明 ### 配置说明
@@ -131,10 +141,44 @@
- `UPLOAD_PROVIDER`: 图片上传服务提供商 - `UPLOAD_PROVIDER`: 图片上传服务提供商
- 默认值: `smms` - 默认值: `smms`
- 说明: 目前支持 SM.MS 图床 - 可选值: `smms`, `picgo`, `cloudflare_imgbed`
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
- `SMMS_SECRET_TOKEN`: SM.MS API Token - `SMMS_SECRET_TOKEN`: SM.MS API Token
- 用途: 用于图片上传到 SM.MS 图床 - 用途: 用于图片上传到 SM.MS 图床的身份验证。
- 获取方式: 需要在 SM.MS 官网注册并获取 - 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取
- `PICGO_API_KEY`: PicGo API Key
- 用途: 用于图片上传到 PicGo 图床的身份验证。
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权则留空即可。
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
#### 流式输出优化配置
- `STREAM_MIN_DELAY`: 最小延迟时间
- 默认值: `0.016`(秒)
- 说明: 长文本输出时使用的最小延迟时间,值越小输出速度越快
- `STREAM_MAX_DELAY`: 最大延迟时间
- 默认值: `0.024`(秒)
- 说明: 短文本输出时使用的最大延迟时间,值越大输出速度越慢
- `STREAM_SHORT_TEXT_THRESHOLD`: 短文本阈值
- 默认值: `10`(字符)
- 说明: 小于此长度的文本被视为短文本,将使用最大延迟输出
- `STREAM_LONG_TEXT_THRESHOLD`: 长文本阈值
- 默认值: `50`(字符)
- 说明: 大于此长度的文本被视为长文本,将使用最小延迟并分块输出
- `STREAM_CHUNK_SIZE`: 长文本分块大小
- 默认值: `5`(字符)
- 说明: 长文本分块输出时,每个块的大小
### ▶️ 运行 ### ▶️ 运行

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings from app.core.config import settings
from app.core.logger import get_gemini_logger from app.core.logger import get_gemini_logger
from app.core.security import SecurityService from app.core.security import SecurityService
@@ -23,7 +23,7 @@ async def get_key_manager():
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)): async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key() return await key_manager.get_next_working_key()
model_service = ModelService(settings.MODEL_SEARCH) model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
@router.get("/models") @router.get("/models")
@@ -36,12 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
api_key = await key_manager.get_next_working_key() api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key) models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental", # 模型名称以及对应的详细信息
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, # 添加搜索模型
"topP": 0.95, "topK": 64, "maxTemperature": 2}) if settings.MODEL_SEARCH:
for name in settings.MODEL_SEARCH:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json return models_json
@@ -62,6 +90,9 @@ async def generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response = await chat_service.generate_content( response = await chat_service.generate_content(
model=model_name, model=model_name,
@@ -92,6 +123,9 @@ async def stream_generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response_stream = chat_service.stream_generate_content( response_stream = chat_service.stream_generate_content(
model=model_name, model=model_name,

View File

@@ -17,7 +17,7 @@ logger = get_openai_logger()
# 初始化服务 # 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.MODEL_SEARCH) model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE)
embedding_service = EmbeddingService(settings.BASE_URL) embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService() image_create_service = ImageCreateService()
@@ -61,6 +61,10 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
try: try:
# 如果model是imagen3,使用paid_key # 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":

View File

@@ -7,6 +7,7 @@ class Settings(BaseSettings):
ALLOWED_TOKENS: List[str] ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True SHOW_THINKING_PROCESS: bool = True
@@ -16,7 +17,17 @@ class Settings(BaseSettings):
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002" CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms" UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = "" SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
TEST_MODEL: str = "gemini-1.5-flash" TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -149,6 +149,229 @@ class QiniuUploader(ImageUploader):
pass pass
class PicGoUploader(ImageUploader):
    """Image uploader for a Chevereto-style API (PicGo endpoint by default)."""

    def __init__(
        self,
        api_key: str,
        api_url: str = "https://www.picgo.net/api/1/upload",
        timeout: float = 30.0,
    ):
        """
        Initialise the uploader.

        Args:
            api_key: API key sent in the ``X-API-Key`` request header.
            api_url: Upload endpoint URL.
            timeout: Seconds passed to ``requests.post`` as its timeout, so a
                stalled upload fails instead of blocking the caller forever.
        """
        self.api_key = api_key
        self.api_url = api_url
        self.timeout = timeout

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        """
        Upload an image to the configured endpoint.

        Args:
            file: Raw image bytes.
            filename: File name reported to the server.

        Returns:
            UploadResponse: Successful response wrapping the image metadata.

        Raises:
            UploadError: On request failure, a non-200 ``status_code`` in the
                JSON body, or a response that cannot be interpreted.
        """
        try:
            response = requests.post(
                self.api_url,
                headers={"X-API-Key": self.api_key},
                files={"source": (filename, file)},
                timeout=self.timeout,
            )
            # Turn HTTP 4xx/5xx into requests.HTTPError.
            response.raise_for_status()
            result = response.json()

            # The JSON body carries its own status code; anything but 200
            # is treated as a failed upload.
            if result.get("status_code") != 200:
                error_message = "Upload failed"
                if "error" in result:
                    error_message = result["error"].get("message", error_message)
                raise UploadError(
                    message=error_message,
                    error_type=UploadErrorType.SERVER_ERROR,
                    status_code=result.get("status_code"),
                    details=result.get("error"),
                )

            image_data = result.get("image", {})
            image_metadata = ImageMetadata(
                width=image_data.get("width", 0),
                height=image_data.get("height", 0),
                filename=image_data.get("filename", filename),
                size=image_data.get("size", 0),
                url=image_data.get("url", ""),
                delete_url=image_data.get("delete_url", None),
            )
            return UploadResponse(
                success=True,
                code="success",
                message=result.get("success", {}).get("message", "Upload success"),
                data=image_metadata,
            )
        except UploadError:
            # Already the right exception type: handled first so it is never
            # re-wrapped by one of the broader handlers below.
            raise
        except requests.RequestException as e:
            # Connection problems, timeouts and HTTP error statuses.
            raise UploadError(
                message=f"Upload request failed: {str(e)}",
                error_type=UploadErrorType.NETWORK_ERROR,
                original_error=e,
            ) from e
        except (KeyError, ValueError, TypeError) as e:
            # The response did not have the expected structure.
            raise UploadError(
                message=f"Invalid response format: {str(e)}",
                error_type=UploadErrorType.PARSE_ERROR,
                original_error=e,
            ) from e
        except Exception as e:
            # Anything unexpected is still reported as an UploadError.
            raise UploadError(
                message=f"Upload failed: {str(e)}",
                error_type=UploadErrorType.UNKNOWN,
                original_error=e,
            ) from e
class CloudFlareImgBedUploader(ImageUploader):
    """Image uploader for a CloudFlare-ImgBed deployment."""

    def __init__(self, auth_code: str, api_url: str, timeout: float = 30.0):
        """
        Initialise the uploader.

        Args:
            auth_code: Auth code for the image bed; an empty value means the
                ``authCode`` query parameter is omitted.
            api_url: Upload API URL.
            timeout: Seconds passed to ``requests.post`` as its timeout, so a
                stalled upload fails instead of blocking the caller forever.
        """
        self.auth_code = auth_code
        self.api_url = api_url
        self.timeout = timeout

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        """
        Upload an image to the CloudFlare image bed.

        Args:
            file: Raw image bytes.
            filename: File name reported to the server.

        Returns:
            UploadResponse: Successful response wrapping the image metadata.

        Raises:
            UploadError: On request failure or when the response is not a
                non-empty list whose first item has a ``src`` value.
        """
        try:
            # Query parameters go through ``params`` so requests URL-encodes
            # the auth code and merges correctly with any query string that
            # is already part of ``api_url``.
            params = {}
            if self.auth_code:
                params["authCode"] = self.auth_code
            params["uploadNameType"] = "origin"

            response = requests.post(
                self.api_url,
                params=params,
                files={"file": (filename, file)},
                timeout=self.timeout,
            )
            # Turn HTTP 4xx/5xx into requests.HTTPError.
            response.raise_for_status()
            result = response.json()

            # Expected shape: a non-empty list.
            if not result or not isinstance(result, list) or len(result) == 0:
                raise UploadError(
                    message="Invalid response format",
                    error_type=UploadErrorType.PARSE_ERROR,
                )

            file_path = result[0].get("src")
            if not file_path:
                raise UploadError(
                    message="Missing file URL in response",
                    error_type=UploadErrorType.PARSE_ERROR,
                )

            # ``src`` may be relative; in that case prefix it with the part
            # of the upload URL that precedes "/upload".
            base_url = self.api_url.split("/upload")[0]
            if file_path.startswith(("http://", "https://")):
                full_url = file_path
            else:
                full_url = f"{base_url}{file_path}"

            # The service does not report dimensions, size or a delete URL,
            # so those fields are filled with placeholder values.
            image_metadata = ImageMetadata(
                width=0,
                height=0,
                filename=filename,
                size=0,
                url=full_url,
                delete_url=None,
            )
            return UploadResponse(
                success=True,
                code="success",
                message="Upload success",
                data=image_metadata,
            )
        except UploadError:
            # Already the right exception type: handled first so it is never
            # re-wrapped by one of the broader handlers below.
            raise
        except requests.RequestException as e:
            # Connection problems, timeouts and HTTP error statuses.
            raise UploadError(
                message=f"Upload request failed: {str(e)}",
                error_type=UploadErrorType.NETWORK_ERROR,
                original_error=e,
            ) from e
        except (KeyError, ValueError, TypeError, IndexError) as e:
            # The response did not have the expected structure.
            raise UploadError(
                message=f"Invalid response format: {str(e)}",
                error_type=UploadErrorType.PARSE_ERROR,
                original_error=e,
            ) from e
        except Exception as e:
            # Anything unexpected is still reported as an UploadError.
            raise UploadError(
                message=f"Upload failed: {str(e)}",
                error_type=UploadErrorType.UNKNOWN,
                original_error=e,
            ) from e
class ImageUploaderFactory: class ImageUploaderFactory:
@staticmethod @staticmethod
def create(provider: str, **credentials) -> ImageUploader: def create(provider: str, **credentials) -> ImageUploader:
@@ -159,5 +382,12 @@ class ImageUploaderFactory:
credentials["access_key"], credentials["access_key"],
credentials["secret_key"] credentials["secret_key"]
) )
elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url)
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
credentials["auth_code"],
credentials["base_url"]
)
raise ValueError(f"Unknown provider: {provider}") raise ValueError(f"Unknown provider: {provider}")

View File

@@ -33,8 +33,8 @@ class GeminiContent(BaseModel):
class GeminiRequest(BaseModel): class GeminiRequest(BaseModel):
contents: List[GeminiContent] contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = [] tools: Optional[List[Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None generationConfig: Optional[GenerationConfig] = {}
systemInstruction: Optional[SystemInstruction] = None systemInstruction: Optional[SystemInstruction] = None

View File

@@ -24,10 +24,18 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: def _get_real_model(self, model: str) -> str:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): if model.endswith("-search"):
model = model[:-7] model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload) response = await client.post(url, json=payload)
@@ -38,8 +46,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout) timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): model = self._get_real_model(model)
model = model[:-7]
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response: async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -1,9 +1,13 @@
# app/services/chat/message_converter.py # app/services/chat/message_converter.py
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests
import base64
SUPPORTED_ROLES = ["user", "model", "system"] SUPPORTED_ROLES = ["user", "model", "system"]
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
class MessageConverter(ABC): class MessageConverter(ABC):
@@ -13,13 +17,36 @@ class MessageConverter(ABC):
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass pass
def _get_mime_type_and_data(base64_string):
"""
从 base64 字符串中提取 MIME 类型和数据。
参数:
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
返回:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = r'data:([^;]+);base64,(.+)'
match = re.match(pattern, base64_string)
if match:
mime_type = match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]: def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"): if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return { return {
"inline_data": { "inline_data": {
"mime_type": "image/jpeg", "mime_type": mime_type,
"data": image_url.split(",")[1] "data": encoded_data
} }
} }
return { return {
@@ -29,12 +56,62 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
} }
def _convert_image_to_base64(url: str, timeout: float = 10.0) -> str:
    """
    Download an image and return its content as a base64 string.

    Args:
        url: Image URL to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so an
            unresponsive host cannot block the caller indefinitely.

    Returns:
        str: Base64-encoded response body.

    Raises:
        Exception: If the response status code is not 200.
        requests.RequestException: If the request itself fails or times out.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch image: {response.status_code}")
    return base64.b64encode(response.content).decode('utf-8')
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
    """
    Convert a text fragment into a one-element list of message parts.

    When the text matches ``IMAGE_URL_PATTERN`` the referenced image is
    downloaded and returned as an ``inlineData`` part; otherwise, or when the
    download fails, the text itself is returned as a ``text`` part.

    Args:
        text: Text that may contain an image URL.

    Returns:
        List[Dict[str, Any]]: A list holding exactly one part.
    """
    found = re.search(IMAGE_URL_PATTERN, text)
    if found is None:
        # No image reference: plain text part.
        return [{"text": text}]

    image_url = found.group(1)
    try:
        encoded = _convert_image_to_base64(image_url)
    except Exception:
        # Best effort: if the image cannot be fetched, fall back to text.
        return [{"text": text}]

    # NOTE(review): only the first match is used and any text around it is
    # not included in the result — confirm this is intended.
    return [{
        "inlineData": {
            "mimeType": "image/png",
            "data": encoded
        }
    }]
class OpenAIMessageConverter(MessageConverter): class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器""" """OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = [] converted_messages = []
system_instruction = None system_instruction_parts = []
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
role = msg.get("role", "") role = msg.get("role", "")
@@ -49,9 +126,18 @@ class OpenAIMessageConverter(MessageConverter):
role = "model" role = "model"
parts = [] parts = []
if isinstance(msg["content"], str) and msg["content"]: # 特别处理最后一个assistant的消息按\n\n分割
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除 # 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.append({"text": msg["content"]}) parts.extend(_process_text_with_image(msg["content"]))
elif isinstance(msg["content"], list): elif isinstance(msg["content"], list):
for content in msg["content"]: for content in msg["content"]:
if isinstance(content, str) and content: if isinstance(content, str) and content:
@@ -64,8 +150,16 @@ class OpenAIMessageConverter(MessageConverter):
if parts: if parts:
if role == "system": if role == "system":
system_instruction = {"role": "system", "parts": parts} system_instruction_parts.extend(parts)
else: else:
converted_messages.append({"role": role, "parts": parts}) converted_messages.append({"role": role, "parts": parts})
return converted_messages, system_instruction system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction

View File

@@ -1,5 +1,6 @@
# app/services/chat/response_handler.py # app/services/chat/response_handler.py
import base64
import json import json
import random import random
import string import string
@@ -8,6 +9,7 @@ from typing import Dict, Any, List, Optional
import time import time
import uuid import uuid
from app.core.config import settings from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
class ResponseHandler(ABC): class ResponseHandler(ABC):
@@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
candidate = response["candidates"][0] candidate = response["candidates"][0]
content = candidate.get("content", {}) content = candidate.get("content", {})
parts = content.get("parts", []) parts = content.get("parts", [])
# if "thinking" in model: if not parts:
# if settings.SHOW_THINKING_PROCESS: return "", []
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if "text" in parts[0]: if "text" in parts[0]:
text = parts[0].get("text") text = parts[0].get("text")
elif "executableCode" in parts[0]: elif "executableCode" in parts[0]:
@@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = _format_execution_result( text = _format_execution_result(
parts[0]["codeExecutionResult"] parts[0]["codeExecutionResult"]
) )
elif "inlineData" in parts[0]:
text = _extract_image_data(parts[0])
else: else:
text = "" text = ""
text = _add_search_link_text(model, candidate, text) text = _add_search_link_text(model, candidate, text)
@@ -235,14 +180,40 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False,
text = candidate["content"]["parts"][0]["text"] text = candidate["content"]["parts"][0]["text"]
else: else:
text = "" text = ""
for part in candidate["content"]["parts"]: if "parts" in candidate["content"]:
text += part.get("text", "") for part in candidate["content"]["parts"]:
if "text" in part:
text += part["text"]
elif "inlineData" in part:
text += _extract_image_data(part)
text = _add_search_link_text(model, candidate, text) text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format) tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else: else:
text = "暂无返回" text = "暂无返回"
return text, tool_calls return text, tool_calls
def _extract_image_data(part: dict) -> str:
    """
    Upload the inline image of a response part and return a markdown link.

    Args:
        part: Response part containing ``inlineData.data`` (base64 image).

    Returns:
        str: ``![image](<url>)`` when the upload reports success, otherwise
        an empty string.

    Raises:
        ValueError: If ``settings.UPLOAD_PROVIDER`` is not one of the
            providers handled here.
    """
    provider = settings.UPLOAD_PROVIDER
    if provider == "smms":
        credentials = {"api_key": settings.SMMS_SECRET_TOKEN}
    elif provider == "picgo":
        credentials = {"api_key": settings.PICGO_API_KEY}
    elif provider == "cloudflare_imgbed":
        credentials = {
            "base_url": settings.CLOUDFLARE_IMGBED_URL,
            "auth_code": settings.CLOUDFLARE_IMGBED_AUTH_CODE,
        }
    else:
        # Fail with a clear configuration error instead of an AttributeError
        # from calling .upload() on None further down.
        raise ValueError(f"Unsupported UPLOAD_PROVIDER for image upload: {provider}")
    image_uploader = ImageUploaderFactory.create(provider=provider, **credentials)

    # Date-based path plus a short random name for the uploaded file.
    current_date = time.strftime("%Y/%m/%d")
    filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"

    # Decode the base64 payload into raw bytes for the uploader.
    bytes_data = base64.b64decode(part["inlineData"]["data"])
    upload_response = image_uploader.upload(bytes_data, filename)
    if upload_response.success:
        return f"![image]({upload_response.data.url})"
    return ""
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]: def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息""" """提取工具调用信息"""
if not parts or not isinstance(parts, list): if not parts or not isinstance(parts, list):

View File

@@ -4,6 +4,7 @@ import asyncio
import math import math
from typing import Any, List, AsyncGenerator, Callable from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
logger_openai = get_openai_logger() logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger() logger_gemini = get_gemini_logger()
@@ -20,8 +21,8 @@ class StreamOptimizer:
min_delay: float = 0.016, min_delay: float = 0.016,
max_delay: float = 0.024, max_delay: float = 0.024,
short_text_threshold: int = 10, short_text_threshold: int = 10,
long_text_threshold: int = 100, long_text_threshold: int = 50,
chunk_size: int = 10): chunk_size: int = 5):
"""初始化流式输出优化器 """初始化流式输出优化器
参数: 参数:
@@ -112,5 +113,20 @@ class StreamOptimizer:
# Default optimizer instances; import and use them directly.
def _make_default_optimizer(stream_logger):
    """Build a StreamOptimizer bound to *stream_logger* using the stream settings."""
    return StreamOptimizer(
        logger=stream_logger,
        min_delay=settings.STREAM_MIN_DELAY,
        max_delay=settings.STREAM_MAX_DELAY,
        short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
        long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
        chunk_size=settings.STREAM_CHUNK_SIZE,
    )


# Both instances share the same tuning values and differ only in their logger.
openai_optimizer = _make_default_optimizer(logger_openai)
gemini_optimizer = _make_default_optimizer(logger_gemini)

View File

@@ -62,14 +62,19 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Build the request payload for a Gemini call.

    Args:
        model: Target model name. A ``-image`` or ``-image-generation``
            suffix switches the payload into image-output mode.
        request: Incoming request; its ``model_dump()`` supplies the
            contents, generation config and system instruction.

    Returns:
        A dict with ``contents``, ``tools``, ``safetySettings`` and
        ``generationConfig``. ``systemInstruction`` is included for every
        model except the image models, where it is removed.
    """
    request_dict = request.model_dump()
    payload = {
        "contents": request_dict.get("contents", []),
        "tools": _build_tools(model, request_dict),
        "safetySettings": _get_safety_settings(model),
        "generationConfig": request_dict.get("generationConfig", {}),
        "systemInstruction": request_dict.get("systemInstruction", ""),
    }
    if model.endswith(("-image", "-image-generation")):
        # Image models are sent without a system instruction.
        payload.pop("systemInstruction")
        # The dumped config can be None rather than missing (assumes the field
        # is optional on GeminiRequest -- TODO confirm); subscripting None would
        # raise TypeError, so fall back to an empty dict before writing to it.
        generation_config = payload["generationConfig"] or {}
        generation_config["responseModalities"] = ["Text", "Image"]
        payload["generationConfig"] = generation_config
    return payload
class GeminiChatService: class GeminiChatService:

View File

@@ -96,11 +96,6 @@ class ImageCreateService:
for index, generated_image in enumerate(response.generated_images): for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes image_data = generated_image.image.image_bytes
image_uploader = None image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
upload_response = image_uploader.upload(image_data,filename)
if request.response_format == "b64_json": if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode('utf-8') base64_image = base64.b64encode(image_data).decode('utf-8')
@@ -109,6 +104,30 @@ class ImageCreateService:
"revised_prompt": request.prompt "revised_prompt": request.prompt
}) })
else: else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
)
else:
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
upload_response = image_uploader.upload(image_data, filename)
images_data.append({ images_data.append({
"url": f"{upload_response.data.url}", "url": f"{upload_response.data.url}",
"revised_prompt": request.prompt "revised_prompt": request.prompt

View File

@@ -7,8 +7,9 @@ from app.core.config import settings
logger = get_model_logger() logger = get_model_logger()
class ModelService: class ModelService:
def __init__(self, model_search: list): def __init__(self, model_search: list, model_image: list):
self.model_search = model_search self.model_search = model_search
self.model_image = model_image
self.base_url = "https://generativelanguage.googleapis.com/v1beta" self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
@@ -57,9 +58,27 @@ class ModelService:
search_model = openai_model.copy() search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search" search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model) openai_format["data"].append(search_model)
if model_id in self.model_image:
image_model = openai_model.copy()
image_model["id"] = f"{model_id}-image"
openai_format["data"].append(image_model)
if settings.CREATE_IMAGE_MODEL: if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy() image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat" image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model) openai_format["data"].append(image_model)
return openai_format return openai_format
def check_model_support(self, model: str) -> bool:
    """Return whether *model* is a model name this service accepts.

    Args:
        model: Requested model name; surrounding whitespace is ignored.

    Returns:
        ``False`` for a missing, non-string or blank name. For a name ending
        in ``-search`` or ``-image``, ``True`` only if the name without that
        suffix is listed in ``settings.MODEL_SEARCH`` or
        ``settings.MODEL_IMAGE`` respectively. ``True`` for any other name.
    """
    if not model or not isinstance(model, str):
        return False

    model = model.strip()
    if not model:
        # A whitespace-only name is as invalid as an empty one; without this
        # guard it would fall through to the final ``return True``.
        return False

    search_suffix = "-search"
    if model.endswith(search_suffix):
        return model[: -len(search_suffix)] in settings.MODEL_SEARCH

    image_suffix = "-image"
    if model.endswith(image_suffix):
        return model[: -len(image_suffix)] in settings.MODEL_IMAGE

    return True

View File

@@ -35,7 +35,7 @@ def _build_tools(
if ( if (
settings.TOOLS_CODE_EXECUTION_ENABLED settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model) and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
and not _has_image_parts(messages) and not _has_image_parts(messages)
): ):
tools.append({"code_execution": {}}) tools.append({"code_execution": {}})
@@ -110,12 +110,16 @@ def _build_payload(
"tools": _build_tools(request, messages), "tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model), "safetySettings": _get_safety_settings(request.model),
} }
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
if ( if (
instruction instruction
and isinstance(instruction, dict) and isinstance(instruction, dict)
and instruction.get("role") == "system" and instruction.get("role") == "system"
and instruction.get("parts") and instruction.get("parts")
and not request.model.endswith("-image")
and not request.model.endswith("-image-generation")
): ):
payload["systemInstruction"] = instruction payload["systemInstruction"] = instruction