Compare commits

...

9 Commits

Author SHA1 Message Date
cr-zhichen
89f2825ac7 feat: 新增对CloudFlare ImgBed的支持,更新环境变量和文档 2025-03-16 04:39:40 +00:00
snaily
985a12554d fix:修改OpenAI消息转换器中assistant消息处理逻辑,将特殊处理的目标从最后一条消息调整为倒数第二条消息。 2025-03-15 21:18:20 +08:00
snaily
37a7a140fc feat:改进消息转换器中的图像处理和消息分割逻辑
添加 _get_mime_type_and_data 函数从 base64 字符串中提取 MIME 类型和数据
修改 _convert_image 函数使用动态检测的 MIME 类型,而非硬编码
将 _process_text_with_image 中的 MIME 类型从 image/jpeg 改为 image/png
简化异常处理逻辑
优化 OpenAIMessageConverter 中的消息分割逻辑,仅对最后一个 assistant 消息进行分割处理
2025-03-15 21:11:10 +08:00
zhanghaoyu
28e67cc3fa 1. modify IMAGE_URL_PATTERN
2. modify import
2025-03-15 12:37:56 +08:00
zhanghaoyu7
d99a0bde93 feat: 新增图文上下文同步 2025-03-14 16:29:03 +08:00
snaily
cb5cd92041 fix: 修正Dockerfile中TOOLS_CODE_EXECUTION_ENABLED环境变量的拼写错误
将TOOLS_CODE_EXECUTION_ENABLED环境变量的值从"fasle"更正为"false",修复了拼写错误。
2025-03-14 13:46:31 +08:00
snaily
0be85e9536 feat(gemini_routes): 添加deepcopy导入
在gemini_routes.py中添加了Python标准库copy模块中的deepcopy函数导入,用于创建对象的深度副本,确保数据操作过程中不会意外修改原始对象。
2025-03-14 13:43:17 +08:00
Toddy
632dee38b3 check model before send request 2025-03-14 04:11:21 +00:00
Toddy
16c28bf1ba combine multiple system instructions into one 2025-03-14 02:55:29 +00:00
12 changed files with 343 additions and 37 deletions

View File

@@ -14,6 +14,8 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx PICGO_API_KEY=xxxx
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
########################################################################## ##########################################################################
#########################stream_optimizer 相关配置######################## #########################stream_optimizer 相关配置########################
STREAM_MIN_DELAY=0.016 STREAM_MIN_DELAY=0.016

View File

@@ -10,7 +10,7 @@ COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]' ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]' ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=fasle ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]' ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
# Expose port # Expose port

View File

@@ -74,8 +74,11 @@
CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0 CREATE_IMAGE_MODEL="imagen-3.0-generate-002" # 图片生成模型默认使用imagen-3.0
# 图片上传配置 # 图片上传配置
UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms UPLOAD_PROVIDER="smms" # 图片上传提供商目前支持smms、picgo、cloudflare_imgbed
SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token SMMS_SECRET_TOKEN="your-smms-token" # SM.MS图床的API Token
PICGO_API_KEY="your-picogo-apikey" # PicoGo图床的API Key 可在 `https://www.picgo.net/settings/api` 获取
CLOUDFLARE_IMGBED_URL="https://xxxxxxx.pages.dev/upload" # CloudFlare 图床上传地址,可自行搭建:`https://github.com/MarSeventh/CloudFlare-ImgBed`
CLOUDFLARE_IMGBED_AUTH_CODE="your-cloudflare-imgber-auth-code" # CloudFlare图床的鉴权key可在项目后台设置若无鉴权则可直接置空。
# stream_optimizer 相关配置 # stream_optimizer 相关配置
STREAM_MIN_DELAY=0.016 STREAM_MIN_DELAY=0.016
@@ -138,10 +141,26 @@
- `UPLOAD_PROVIDER`: 图片上传服务提供商 - `UPLOAD_PROVIDER`: 图片上传服务提供商
- 默认值: `smms` - 默认值: `smms`
- 说明: 目前支持 SM.MS 图床 - 可选值: `smms`, `picgo`, `cloudflare_imgbed`
- 说明: 用于选择图片上传的服务提供商。目前支持 SM.MS 图床, PicGo 图床, 以及 Cloudflare ImgBed。
- `SMMS_SECRET_TOKEN`: SM.MS API Token - `SMMS_SECRET_TOKEN`: SM.MS API Token
- 用途: 用于图片上传到 SM.MS 图床 - 用途: 用于图片上传到 SM.MS 图床的身份验证。
- 获取方式: 需要在 SM.MS 官网注册并获取 - 获取方式: 需要在 [SM.MS 官网](https://sm.ms/) 注册并获取
- `PICGO_API_KEY`: PicGo API Key
- 用途: 用于图片上传到 PicGo 图床的身份验证。
- 获取方式: 可在 [PicGo 官网](https://www.picgo.net/settings/api) 的设置页面 API 选项中获取。
- `CLOUDFLARE_IMGBED_URL`: Cloudflare ImgBed 上传地址
- 用途: 指定 Cloudflare ImgBed 图床的上传 API 地址。
- 获取方式: 如果您自行搭建了 Cloudflare ImgBed 服务,请填写您的服务部署地址。参考 [Cloudflare-ImgBed 项目](https://github.com/MarSeventh/CloudFlare-ImgBed) 自行搭建。
- 注意: URL 必须以 `https://` 开头,并指向 `/upload` 路径 ,例如 `https://cloudflare-imgbed-7b0.pages.dev/upload`。
- `CLOUDFLARE_IMGBED_AUTH_CODE`: Cloudflare ImgBed 鉴权 Key
- 用途: 用于 Cloudflare ImgBed 图床的身份验证。
- 说明: 如果您的 Cloudflare ImgBed 服务启用了鉴权,请填写鉴权 Key。若未启用鉴权则留空即可。
- 获取方式: 在 Cloudflare ImgBed 项目的后台设置中获取,或在搭建时自行设置。
#### 流式输出优化配置 #### 流式输出优化配置

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings from app.core.config import settings
from app.core.logger import get_gemini_logger from app.core.logger import get_gemini_logger
from app.core.security import SecurityService from app.core.security import SecurityService
@@ -36,18 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
api_key = await key_manager.get_next_working_key() api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key) models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental", # 模型名称以及对应的详细信息
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, # 添加搜索模型
"topP": 0.95, "topK": 64, "maxTemperature": 2}) if settings.MODEL_SEARCH:
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-image", "version": "2.0", for name in settings.MODEL_SEARCH:
"displayName": "Gemini 2.0 Flash Image Experimental", model = model_mapping.get(name, None)
"description": "Gemini 2.0 Flash Image Experimental", "inputTokenLimit": 32767, if not model:
"outputTokenLimit": 8192, continue
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2}) item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if settings.MODEL_IMAGE:
for name in settings.MODEL_IMAGE:
model = model_mapping.get(name, None)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json return models_json
@@ -68,6 +90,9 @@ async def generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response = await chat_service.generate_content( response = await chat_service.generate_content(
model=model_name, model=model_name,
@@ -98,6 +123,9 @@ async def stream_generate_content(
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try: try:
response_stream = chat_service.stream_generate_content( response_stream = chat_service.stream_generate_content(
model=model_name, model=model_name,

View File

@@ -61,6 +61,10 @@ async def chat_completion(
logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}") logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
try: try:
# 如果model是imagen3,使用paid_key # 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":

View File

@@ -18,6 +18,8 @@ class Settings(BaseSettings):
UPLOAD_PROVIDER: str = "smms" UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = "" SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = "" PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
TEST_MODEL: str = "gemini-1.5-flash" TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置 # 流式输出优化器配置

View File

@@ -258,6 +258,119 @@ class PicGoUploader(ImageUploader):
original_error=e original_error=e
) )
class CloudFlareImgBedUploader(ImageUploader):
"""CloudFlare图床上传器"""
def __init__(self, auth_code: str, api_url: str):
"""
初始化CloudFlare图床上传器
Args:
auth_code: 认证码
api_url: 上传API地址
"""
self.auth_code = auth_code
self.api_url = api_url
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
上传图片到CloudFlare图床
Args:
file: 图片文件二进制数据
filename: 文件名
Returns:
UploadResponse: 上传响应对象
Raises:
UploadError: 上传失败时抛出异常
"""
try:
# 准备请求URL添加认证码参数如果存在
if self.auth_code:
request_url = f"{self.api_url}?authCode={self.auth_code}"
else:
request_url = self.api_url
# 准备文件数据
files = {
"file": (filename, file)
}
# 发送请求
response = requests.post(
request_url,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证响应格式
if not result or not isinstance(result, list) or len(result) == 0:
raise UploadError(
message="Invalid response format",
error_type=UploadErrorType.PARSE_ERROR
)
# 获取文件URL
file_path = result[0].get("src")
if not file_path:
raise UploadError(
message="Missing file URL in response",
error_type=UploadErrorType.PARSE_ERROR
)
# 构建完整URL如果返回的是相对路径
base_url = self.api_url.split("/upload")[0]
full_url = file_path if file_path.startswith(("http://", "https://")) else f"{base_url}{file_path}"
# 构建图片元数据注意CloudFlare-ImgBed不返回所有元数据所以部分字段为默认值
image_metadata = ImageMetadata(
width=0, # CloudFlare-ImgBed不返回宽度
height=0, # CloudFlare-ImgBed不返回高度
filename=filename,
size=0, # CloudFlare-ImgBed不返回大小
url=full_url,
delete_url=None # CloudFlare-ImgBed不返回删除URL
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(
message=f"Upload request failed: {str(e)}",
error_type=UploadErrorType.NETWORK_ERROR,
original_error=e
)
except (KeyError, ValueError, TypeError, IndexError) as e:
# 处理响应解析错误
raise UploadError(
message=f"Invalid response format: {str(e)}",
error_type=UploadErrorType.PARSE_ERROR,
original_error=e
)
except UploadError:
# 重新抛出已经是 UploadError 类型的异常
raise
except Exception as e:
# 处理其他未预期的错误
raise UploadError(
message=f"Upload failed: {str(e)}",
error_type=UploadErrorType.UNKNOWN,
original_error=e
)
class ImageUploaderFactory: class ImageUploaderFactory:
@staticmethod @staticmethod
@@ -272,4 +385,9 @@ class ImageUploaderFactory:
elif provider == "picgo": elif provider == "picgo":
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload") api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
return PicGoUploader(credentials["api_key"], api_url) return PicGoUploader(credentials["api_key"], api_url)
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
credentials["auth_code"],
credentials["base_url"]
)
raise ValueError(f"Unknown provider: {provider}") raise ValueError(f"Unknown provider: {provider}")

View File

@@ -24,12 +24,18 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: def _get_real_model(self, model: str) -> str:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): if model.endswith("-search"):
model = model[:-7] model = model[:-7]
if model.endswith("-image"): if model.endswith("-image"):
model = model[:-6] model = model[:-6]
return model
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload) response = await client.post(url, json=payload)
@@ -40,10 +46,8 @@ class GeminiApiClient(ApiClient):
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout) timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"): model = self._get_real_model(model)
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response: async with client.stream(method="POST", url=url, json=payload) as response:

View File

@@ -1,9 +1,13 @@
# app/services/chat/message_converter.py # app/services/chat/message_converter.py
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests
import base64
SUPPORTED_ROLES = ["user", "model", "system"] SUPPORTED_ROLES = ["user", "model", "system"]
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
class MessageConverter(ABC): class MessageConverter(ABC):
@@ -13,13 +17,36 @@ class MessageConverter(ABC):
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass pass
def _get_mime_type_and_data(base64_string):
"""
从 base64 字符串中提取 MIME 类型和数据。
参数:
base64_string (str): 可能包含 MIME 类型信息的 base64 字符串
返回:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = r'data:([^;]+);base64,(.+)'
match = re.match(pattern, base64_string)
if match:
mime_type = match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def _convert_image(image_url: str) -> Dict[str, Any]: def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"): if image_url.startswith("data:image"):
mime_type, encoded_data = _get_mime_type_and_data(image_url)
return { return {
"inline_data": { "inline_data": {
"mime_type": "image/jpeg", "mime_type": mime_type,
"data": image_url.split(",")[1] "data": encoded_data
} }
} }
return { return {
@@ -29,12 +56,62 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
} }
def _convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
"""
处理可能包含图片URL的文本提取图片并转换为base64
Args:
text: 可能包含图片URL的文本
Returns:
List[Dict[str, Any]]: 包含文本和图片的部分列表
"""
parts = []
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(1)
# 将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append({
"inlineData": {
"mimeType": "image/png",
"data": base64_data
}
})
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
else:
# 没有图片URL作为纯文本处理
parts.append({"text": text})
return parts
class OpenAIMessageConverter(MessageConverter): class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器""" """OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = [] converted_messages = []
system_instruction = None system_instruction_parts = []
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
role = msg.get("role", "") role = msg.get("role", "")
@@ -49,9 +126,18 @@ class OpenAIMessageConverter(MessageConverter):
role = "model" role = "model"
parts = [] parts = []
if isinstance(msg["content"], str) and msg["content"]: # 特别处理最后一个assistant的消息按\n\n分割
if role == "assistant" and idx == len(messages) - 2 and isinstance(msg["content"], str) and msg["content"]:
# 按\n\n分割消息
content_parts = msg["content"].split("\n\n")
for part in content_parts:
if not part.strip(): # 跳过空内容
continue
# 处理可能包含图片的文本
parts.extend(_process_text_with_image(part))
elif isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除 # 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.append({"text": msg["content"]}) parts.extend(_process_text_with_image(msg["content"]))
elif isinstance(msg["content"], list): elif isinstance(msg["content"], list):
for content in msg["content"]: for content in msg["content"]:
if isinstance(content, str) and content: if isinstance(content, str) and content:
@@ -64,8 +150,16 @@ class OpenAIMessageConverter(MessageConverter):
if parts: if parts:
if role == "system": if role == "system":
system_instruction = {"role": "system", "parts": parts} system_instruction_parts.extend(parts)
else: else:
converted_messages.append({"role": role, "parts": parts}) converted_messages.append({"role": role, "parts": parts})
return converted_messages, system_instruction system_instruction = (
None
if not system_instruction_parts
else {
"role": "system",
"parts": system_instruction_parts,
}
)
return converted_messages, system_instruction

View File

@@ -200,6 +200,8 @@ def _extract_image_data(part: dict) -> str:
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN) image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
elif settings.UPLOAD_PROVIDER == "picgo": elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY) image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,base_url=settings.CLOUDFLARE_IMGBED_URL,auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE)
current_date = time.strftime("%Y/%m/%d") current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png" filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"] base64_data = part["inlineData"]["data"]
@@ -207,7 +209,7 @@ def _extract_image_data(part: dict) -> str:
bytes_data = base64.b64decode(base64_data) bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data,filename) upload_response = image_uploader.upload(bytes_data,filename)
if upload_response.success: if upload_response.success:
text = f"\n![image]({upload_response.data.url})\n" text = f"![image]({upload_response.data.url})"
else: else:
text = "" text = ""
return text return text

View File

@@ -96,11 +96,6 @@ class ImageCreateService:
for index, generated_image in enumerate(response.generated_images): for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes image_data = generated_image.image.image_bytes
image_uploader = None image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
upload_response = image_uploader.upload(image_data,filename)
if request.response_format == "b64_json": if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode('utf-8') base64_image = base64.b64encode(image_data).decode('utf-8')
@@ -109,6 +104,30 @@ class ImageCreateService:
"revised_prompt": request.prompt "revised_prompt": request.prompt
}) })
else: else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
)
else:
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
upload_response = image_uploader.upload(image_data, filename)
images_data.append({ images_data.append({
"url": f"{upload_response.data.url}", "url": f"{upload_response.data.url}",
"revised_prompt": request.prompt "revised_prompt": request.prompt

View File

@@ -68,3 +68,17 @@ class ModelService:
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat" image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model) openai_format["data"].append(image_model)
return openai_format return openai_format
def check_model_support(self, model: str) -> bool:
if not model or not isinstance(model, str):
return False
model = model.strip()
if model.endswith("-search"):
model = model[:-7]
return model in settings.MODEL_SEARCH
if model.endswith("-image"):
model = model[:-6]
return model in settings.MODEL_IMAGE
return True