mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 14:21:27 +08:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
89f2825ac7 | ||
|
|
985a12554d | ||
|
|
37a7a140fc | ||
|
|
28e67cc3fa | ||
|
|
d99a0bde93 | ||
|
|
cb5cd92041 | ||
|
|
0be85e9536 | ||
|
|
632dee38b3 | ||
|
|
16c28bf1ba |
@@ -14,6 +14,8 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
|||||||
UPLOAD_PROVIDER=smms
|
UPLOAD_PROVIDER=smms
|
||||||
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||||
PICGO_API_KEY=xxxx
|
PICGO_API_KEY=xxxx
|
||||||
|
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||||
##########################################################################
|
##########################################################################
|
||||||
#########################stream_optimizer 相关配置########################
|
#########################stream_optimizer 相关配置########################
|
||||||
STREAM_MIN_DELAY=0.016
|
STREAM_MIN_DELAY=0.016
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ COPY ./app /app/app
|
|||||||
ENV API_KEYS='["your_api_key_1"]'
|
ENV API_KEYS='["your_api_key_1"]'
|
||||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||||
ENV TOOLS_CODE_EXECUTION_ENABLED=fasle
|
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||||
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
|
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
|
||||||
|
|
||||||
# Expose port
|
# Expose port
|
||||||
|
|||||||
27
README.md
27
README.md
@@ -74,8 +74,11 @@
|
|||||||
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型,默认使用imagen-3.0
|
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型,默认使用imagen-3.0
|
||||||
|
|
||||||
# 图片上传配置
|
# 图片上传配置
|
||||||
UPLOAD_PROVIDER="smms" # 图片上传提供商,目前支持smms
|
UPLOAD_PROVIDER="smms" # 图片上传提供商,目前支持smms、picgo、cloudflare_imgbed
|
||||||
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
|
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
|
||||||
|
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
|
||||||
|
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key,可在项目后台设置,若无鉴权则可直接置空。
|
||||||
|
|
||||||
# stream_optimizer 相关配置
|
# stream_optimizer 相关配置
|
||||||
STREAM_MIN_DELAY=0.016
|
STREAM_MIN_DELAY=0.016
|
||||||
@@ -138,10 +141,26 @@
|
|||||||
|
|
||||||
- `UPLOAD_PROVIDER`: 图片上传服务提供商
|
- `UPLOAD_PROVIDER`: 图片上传服务提供商
|
||||||
- 默认值: `smms`
|
- 默认值: `smms`
|
||||||
- 说明: 目前支持 SM.MS 图床
|
- 可选值: `smms`, `picgo`, `cloudflare_imgbed`
|
||||||
|
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
|
||||||
|
|
||||||
- `SMMS_SECRET_TOKEN`: SM.MS API Token
|
- `SMMS_SECRET_TOKEN`: SM.MS API Token
|
||||||
- 用途: 用于图片上传到 SM.MS 图床
|
- 用途: 用于图片上传到 SM.MS 图床的身份验证。
|
||||||
- 获取方式: 需要在 SM.MS 官网注册并获取
|
- 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取。
|
||||||
|
|
||||||
|
- `PICGO_API_KEY`: PicGo API Key
|
||||||
|
- 用途: 用于图片上传到 PicGo 图床的身份验证。
|
||||||
|
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
|
||||||
|
|
||||||
|
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
|
||||||
|
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
|
||||||
|
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
|
||||||
|
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
|
||||||
|
|
||||||
|
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
|
||||||
|
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
|
||||||
|
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权,则留空即可。
|
||||||
|
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
|
||||||
|
|
||||||
#### 流式输出优化配置
|
#### 流式输出优化配置
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from copy import deepcopy
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logger import get_gemini_logger
|
from app.core.logger import get_gemini_logger
|
||||||
from app.core.security import SecurityService
|
from app.core.security import SecurityService
|
||||||
@@ -36,18 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
|
|||||||
api_key = await key_manager.get_next_working_key()
|
api_key = await key_manager.get_next_working_key()
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
models_json = model_service.get_gemini_models(api_key)
|
models_json = model_service.get_gemini_models(api_key)
|
||||||
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
|
|
||||||
"displayName": "Gemini 2.0 Flash Search Experimental",
|
# 模型名称以及对应的详细信息
|
||||||
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
|
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||||
"outputTokenLimit": 8192,
|
|
||||||
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
|
# 添加搜索模型
|
||||||
"topP": 0.95, "topK": 64, "maxTemperature": 2})
|
if settings.MODEL_SEARCH:
|
||||||
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-image", "version": "2.0",
|
for name in settings.MODEL_SEARCH:
|
||||||
"displayName": "Gemini 2.0 Flash Image Experimental",
|
model = model_mapping.get(name, None)
|
||||||
"description": "Gemini 2.0 Flash Image Experimental", "inputTokenLimit": 32767,
|
if not model:
|
||||||
"outputTokenLimit": 8192,
|
continue
|
||||||
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
|
|
||||||
"topP": 0.95, "topK": 64, "maxTemperature": 2})
|
item = deepcopy(model)
|
||||||
|
item["name"] = f"models/{name}-search"
|
||||||
|
display_name = f'{item.get("displayName")} For Search'
|
||||||
|
item["displayName"] = display_name
|
||||||
|
item["description"] = display_name
|
||||||
|
|
||||||
|
models_json["models"].append(item)
|
||||||
|
|
||||||
|
# 添加图像生成模型
|
||||||
|
if settings.MODEL_IMAGE:
|
||||||
|
for name in settings.MODEL_IMAGE:
|
||||||
|
model = model_mapping.get(name, None)
|
||||||
|
if not model:
|
||||||
|
continue
|
||||||
|
|
||||||
|
item = deepcopy(model)
|
||||||
|
item["name"] = f"models/{name}-image"
|
||||||
|
display_name = f'{item.get("displayName")} For Image'
|
||||||
|
item["displayName"] = display_name
|
||||||
|
item["description"] = display_name
|
||||||
|
|
||||||
|
models_json["models"].append(item)
|
||||||
|
|
||||||
return models_json
|
return models_json
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +90,9 @@ async def generate_content(
|
|||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(model_name):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await chat_service.generate_content(
|
response = await chat_service.generate_content(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@@ -98,6 +123,9 @@ async def stream_generate_content(
|
|||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(model_name):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_stream = chat_service.stream_generate_content(
|
response_stream = chat_service.stream_generate_content(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
|
|||||||
@@ -61,6 +61,10 @@ async def chat_completion(
|
|||||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
|
if not model_service.check_model_support(request.model):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 如果model是imagen3,使用paid_key
|
# 如果model是imagen3,使用paid_key
|
||||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ class Settings(BaseSettings):
|
|||||||
UPLOAD_PROVIDER: str = "smms"
|
UPLOAD_PROVIDER: str = "smms"
|
||||||
SMMS_SECRET_TOKEN: str = ""
|
SMMS_SECRET_TOKEN: str = ""
|
||||||
PICGO_API_KEY: str = ""
|
PICGO_API_KEY: str = ""
|
||||||
|
CLOUDFLARE_IMGBED_URL: str = ""
|
||||||
|
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||||
TEST_MODEL: str = "gemini-1.5-flash"
|
TEST_MODEL: str = "gemini-1.5-flash"
|
||||||
|
|
||||||
# 流式输出优化器配置
|
# 流式输出优化器配置
|
||||||
|
|||||||
@@ -258,6 +258,119 @@ class PicGoUploader(ImageUploader):
|
|||||||
original_error=e
|
original_error=e
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudFlareImgBedUploader(ImageUploader):
|
||||||
|
"""CloudFlare图床上传器"""
|
||||||
|
|
||||||
|
def __init__(self, auth_code: str, api_url: str):
|
||||||
|
"""
|
||||||
|
初始化CloudFlare图床上传器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_code: 认证码
|
||||||
|
api_url: 上传API地址
|
||||||
|
"""
|
||||||
|
self.auth_code = auth_code
|
||||||
|
self.api_url = api_url
|
||||||
|
|
||||||
|
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||||
|
"""
|
||||||
|
上传图片到CloudFlare图床
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 图片文件二进制数据
|
||||||
|
filename: 文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UploadResponse: 上传响应对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UploadError: 上传失败时抛出异常
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 准备请求URL(添加认证码参数,如果存在)
|
||||||
|
if self.auth_code:
|
||||||
|
request_url = f"{self.api_url}?authCode={self.auth_code}"
|
||||||
|
else:
|
||||||
|
request_url = self.api_url
|
||||||
|
|
||||||
|
# 准备文件数据
|
||||||
|
files = {
|
||||||
|
"file": (filename, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
response = requests.post(
|
||||||
|
request_url,
|
||||||
|
files=files
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应状态
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# 验证响应格式
|
||||||
|
if not result or not isinstance(result, list) or len(result) == 0:
|
||||||
|
raise UploadError(
|
||||||
|
message="Invalid response format",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取文件URL
|
||||||
|
file_path = result[0].get("src")
|
||||||
|
if not file_path:
|
||||||
|
raise UploadError(
|
||||||
|
message="Missing file URL in response",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建完整URL(如果返回的是相对路径)
|
||||||
|
base_url = self.api_url.split("/upload")[0]
|
||||||
|
full_url = file_path if file_path.startswith(("http://", "https://")) else f"{base_url}{file_path}"
|
||||||
|
|
||||||
|
# 构建图片元数据(注意:CloudFlare-ImgBed不返回所有元数据,所以部分字段为默认值)
|
||||||
|
image_metadata = ImageMetadata(
|
||||||
|
width=0, # CloudFlare-ImgBed不返回宽度
|
||||||
|
height=0, # CloudFlare-ImgBed不返回高度
|
||||||
|
filename=filename,
|
||||||
|
size=0, # CloudFlare-ImgBed不返回大小
|
||||||
|
url=full_url,
|
||||||
|
delete_url=None # CloudFlare-ImgBed不返回删除URL
|
||||||
|
)
|
||||||
|
|
||||||
|
return UploadResponse(
|
||||||
|
success=True,
|
||||||
|
code="success",
|
||||||
|
message="Upload success",
|
||||||
|
data=image_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
# 处理网络请求相关错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload request failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.NETWORK_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except (KeyError, ValueError, TypeError, IndexError) as e:
|
||||||
|
# 处理响应解析错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Invalid response format: {str(e)}",
|
||||||
|
error_type=UploadErrorType.PARSE_ERROR,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
except UploadError:
|
||||||
|
# 重新抛出已经是 UploadError 类型的异常
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他未预期的错误
|
||||||
|
raise UploadError(
|
||||||
|
message=f"Upload failed: {str(e)}",
|
||||||
|
error_type=UploadErrorType.UNKNOWN,
|
||||||
|
original_error=e
|
||||||
|
)
|
||||||
|
|
||||||
class ImageUploaderFactory:
|
class ImageUploaderFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -272,4 +385,9 @@ class ImageUploaderFactory:
|
|||||||
elif provider == "picgo":
|
elif provider == "picgo":
|
||||||
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
|
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
|
||||||
return PicGoUploader(credentials["api_key"], api_url)
|
return PicGoUploader(credentials["api_key"], api_url)
|
||||||
|
elif provider == "cloudflare_imgbed":
|
||||||
|
return CloudFlareImgBedUploader(
|
||||||
|
credentials["auth_code"],
|
||||||
|
credentials["base_url"]
|
||||||
|
)
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|||||||
@@ -24,12 +24,18 @@ class GeminiApiClient(ApiClient):
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
def _get_real_model(self, model: str) -> str:
|
||||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
|
||||||
if model.endswith("-search"):
|
if model.endswith("-search"):
|
||||||
model = model[:-7]
|
model = model[:-7]
|
||||||
if model.endswith("-image"):
|
if model.endswith("-image"):
|
||||||
model = model[:-6]
|
model = model[:-6]
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||||
|
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||||
|
model = self._get_real_model(model)
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||||
response = await client.post(url, json=payload)
|
response = await client.post(url, json=payload)
|
||||||
@@ -40,10 +46,8 @@ class GeminiApiClient(ApiClient):
|
|||||||
|
|
||||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||||
if model.endswith("-search"):
|
model = self._get_real_model(model)
|
||||||
model = model[:-7]
|
|
||||||
if model.endswith("-image"):
|
|
||||||
model = model[:-6]
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
# app/services/chat/message_converter.py
|
# app/services/chat/message_converter.py
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
|
||||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||||
|
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
|
||||||
|
|
||||||
|
|
||||||
class MessageConverter(ABC):
|
class MessageConverter(ABC):
|
||||||
@@ -13,13 +17,36 @@ class MessageConverter(ABC):
|
|||||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _get_mime_type_and_data(base64_string):
|
||||||
|
"""
|
||||||
|
从 base64 字符串中提取 MIME 类型和数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple: (mime_type, encoded_data)
|
||||||
|
"""
|
||||||
|
# 检查字符串是否以 "data:" 格式开始
|
||||||
|
if base64_string.startswith('data:'):
|
||||||
|
# 提取 MIME 类型和数据
|
||||||
|
pattern = r'data:([^;]+);base64,(.+)'
|
||||||
|
match = re.match(pattern, base64_string)
|
||||||
|
if match:
|
||||||
|
mime_type = match.group(1)
|
||||||
|
encoded_data = match.group(2)
|
||||||
|
return mime_type, encoded_data
|
||||||
|
|
||||||
|
# 如果不是预期格式,假定它只是数据部分
|
||||||
|
return None, base64_string
|
||||||
|
|
||||||
def _convert_image(image_url: str) -> Dict[str, Any]:
|
def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||||
if image_url.startswith("data:image"):
|
if image_url.startswith("data:image"):
|
||||||
|
mime_type, encoded_data = _get_mime_type_and_data(image_url)
|
||||||
return {
|
return {
|
||||||
"inline_data": {
|
"inline_data": {
|
||||||
"mime_type": "image/jpeg",
|
"mime_type": mime_type,
|
||||||
"data": image_url.split(",")[1]
|
"data": encoded_data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
@@ -29,12 +56,62 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_image_to_base64(url: str) -> str:
|
||||||
|
"""
|
||||||
|
将图片URL转换为base64编码
|
||||||
|
Args:
|
||||||
|
url: 图片URL
|
||||||
|
Returns:
|
||||||
|
str: base64编码的图片数据
|
||||||
|
"""
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# 将图片内容转换为base64
|
||||||
|
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||||
|
return img_data
|
||||||
|
else:
|
||||||
|
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 可能包含图片URL的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||||
|
if img_url_match:
|
||||||
|
# 提取URL
|
||||||
|
img_url = img_url_match.group(1)
|
||||||
|
# 将URL对应的图片转换为base64
|
||||||
|
try:
|
||||||
|
base64_data = _convert_image_to_base64(img_url)
|
||||||
|
parts.append({
|
||||||
|
"inlineData": {
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"data": base64_data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
# 如果转换失败,回退到文本模式
|
||||||
|
parts.append({"text": text})
|
||||||
|
else:
|
||||||
|
# 没有图片URL,作为纯文本处理
|
||||||
|
parts.append({"text": text})
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMessageConverter(MessageConverter):
|
class OpenAIMessageConverter(MessageConverter):
|
||||||
"""OpenAI消息格式转换器"""
|
"""OpenAI消息格式转换器"""
|
||||||
|
|
||||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
converted_messages = []
|
converted_messages = []
|
||||||
system_instruction = None
|
system_instruction_parts = []
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
role = msg.get("role", "")
|
role = msg.get("role", "")
|
||||||
@@ -49,9 +126,18 @@ class OpenAIMessageConverter(MessageConverter):
|
|||||||
role = "model"
|
role = "model"
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if isinstance(msg["content"], str) and msg["content"]:
|
# 特别处理最后一个assistant的消息,按\n\n分割
|
||||||
|
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
|
||||||
|
# 按\n\n分割消息
|
||||||
|
content_parts = msg["content"].split("\n\n")
|
||||||
|
for part in content_parts:
|
||||||
|
if not part.strip(): # 跳过空内容
|
||||||
|
continue
|
||||||
|
# 处理可能包含图片的文本
|
||||||
|
parts.extend(_process_text_with_image(part))
|
||||||
|
elif isinstance(msg["content"], str) and msg["content"]:
|
||||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||||
parts.append({"text": msg["content"]})
|
parts.extend(_process_text_with_image(msg["content"]))
|
||||||
elif isinstance(msg["content"], list):
|
elif isinstance(msg["content"], list):
|
||||||
for content in msg["content"]:
|
for content in msg["content"]:
|
||||||
if isinstance(content, str) and content:
|
if isinstance(content, str) and content:
|
||||||
@@ -64,8 +150,16 @@ class OpenAIMessageConverter(MessageConverter):
|
|||||||
|
|
||||||
if parts:
|
if parts:
|
||||||
if role == "system":
|
if role == "system":
|
||||||
system_instruction = {"role": "system", "parts": parts}
|
system_instruction_parts.extend(parts)
|
||||||
else:
|
else:
|
||||||
converted_messages.append({"role": role, "parts": parts})
|
converted_messages.append({"role": role, "parts": parts})
|
||||||
|
|
||||||
return converted_messages, system_instruction
|
system_instruction = (
|
||||||
|
None
|
||||||
|
if not system_instruction_parts
|
||||||
|
else {
|
||||||
|
"role": "system",
|
||||||
|
"parts": system_instruction_parts,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return converted_messages, system_instruction
|
||||||
@@ -200,6 +200,8 @@ def _extract_image_data(part: dict) -> str:
|
|||||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
|
||||||
|
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||||
|
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
|
||||||
current_date = time.strftime("%Y/%m/%d")
|
current_date = time.strftime("%Y/%m/%d")
|
||||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||||
base64_data = part["inlineData"]["data"]
|
base64_data = part["inlineData"]["data"]
|
||||||
@@ -207,7 +209,7 @@ def _extract_image_data(part: dict) -> str:
|
|||||||
bytes_data = base64.b64decode(base64_data)
|
bytes_data = base64.b64decode(base64_data)
|
||||||
upload_response = image_uploader.upload(bytes_data,filename)
|
upload_response = image_uploader.upload(bytes_data,filename)
|
||||||
if upload_response.success:
|
if upload_response.success:
|
||||||
text = f"\n\n"
|
text = f""
|
||||||
else:
|
else:
|
||||||
text = ""
|
text = ""
|
||||||
return text
|
return text
|
||||||
|
|||||||
@@ -96,11 +96,6 @@ class ImageCreateService:
|
|||||||
for index, generated_image in enumerate(response.generated_images):
|
for index, generated_image in enumerate(response.generated_images):
|
||||||
image_data = generated_image.image.image_bytes
|
image_data = generated_image.image.image_bytes
|
||||||
image_uploader = None
|
image_uploader = None
|
||||||
if settings.UPLOAD_PROVIDER == "smms":
|
|
||||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
|
||||||
current_date = time.strftime("%Y/%m/%d")
|
|
||||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
|
||||||
upload_response = image_uploader.upload(image_data,filename)
|
|
||||||
|
|
||||||
if request.response_format == "b64_json":
|
if request.response_format == "b64_json":
|
||||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
@@ -109,6 +104,30 @@ class ImageCreateService:
|
|||||||
"revised_prompt": request.prompt
|
"revised_prompt": request.prompt
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
|
current_date = time.strftime("%Y/%m/%d")
|
||||||
|
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||||
|
|
||||||
|
if settings.UPLOAD_PROVIDER == "smms":
|
||||||
|
image_uploader = ImageUploaderFactory.create(
|
||||||
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
|
api_key=settings.SMMS_SECRET_TOKEN
|
||||||
|
)
|
||||||
|
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||||
|
image_uploader = ImageUploaderFactory.create(
|
||||||
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
|
api_key=settings.PICGO_API_KEY
|
||||||
|
)
|
||||||
|
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||||
|
image_uploader = ImageUploaderFactory.create(
|
||||||
|
provider=settings.UPLOAD_PROVIDER,
|
||||||
|
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||||
|
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
|
||||||
|
|
||||||
|
upload_response = image_uploader.upload(image_data, filename)
|
||||||
|
|
||||||
images_data.append({
|
images_data.append({
|
||||||
"url": f"{upload_response.data.url}",
|
"url": f"{upload_response.data.url}",
|
||||||
"revised_prompt": request.prompt
|
"revised_prompt": request.prompt
|
||||||
|
|||||||
@@ -68,3 +68,17 @@ class ModelService:
|
|||||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||||
openai_format["data"].append(image_model)
|
openai_format["data"].append(image_model)
|
||||||
return openai_format
|
return openai_format
|
||||||
|
|
||||||
|
def check_model_support(self, model: str) -> bool:
|
||||||
|
if not model or not isinstance(model, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
model = model.strip()
|
||||||
|
if model.endswith("-search"):
|
||||||
|
model = model[:-7]
|
||||||
|
return model in settings.MODEL_SEARCH
|
||||||
|
if model.endswith("-image"):
|
||||||
|
model = model[:-6]
|
||||||
|
return model in settings.MODEL_IMAGE
|
||||||
|
|
||||||
|
return True
|
||||||
Reference in New Issue
Block a user