mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-28 03:02:12 +08:00
feat: 添加图片生成功能及相关配置
- 添加图片生成相关配置和环境变量 - 新增图片上传服务和模型定义 - 扩展模型服务以支持图片生成模型 - 添加图片生成响应处理器 - 更新README文档以反映新功能 - 添加GitHub Actions发布工作流
This commit is contained in:
15
.env.example
Normal file
15
.env.example
Normal file
@@ -0,0 +1,15 @@
|
||||
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
|
||||
ALLOWED_TOKENS=["sk-123456"]
|
||||
# AUTH_TOKEN=sk-123456
|
||||
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED=true
|
||||
SHOW_SEARCH_LINK=true
|
||||
SHOW_THINKING_PROCESS=true
|
||||
BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
MAX_FAILURES=10
|
||||
#########################image_generate 相关配置###########################
|
||||
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
|
||||
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
UPLOAD_PROVIDER=smms
|
||||
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
##########################################################################
|
||||
41
.github/workflows/release.yml
vendored
Normal file
41
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: Publish Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Step 1: 检出代码库
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Step 2: 自动生成 Release
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ github.ref_name }}
|
||||
release_name: ${{ github.ref_name }}
|
||||
body: |
|
||||
## Release Notes
|
||||
- 自动发布版本。
|
||||
- 请根据需求更新对应内容。
|
||||
draft: false
|
||||
prerelease: false
|
||||
|
||||
# Step 3: 可选,上传构建文件
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
asset_path: ./your-build-file.zip # 替换为你的构建文件路径
|
||||
asset_name: your-build-file.zip # 替换为你的文件名
|
||||
asset_content_type: application/zip
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"commentTranslate.source": "upupnoah.chatgpt-comment-translateX-chatgpt"
|
||||
}
|
||||
33
README.md
33
README.md
@@ -16,6 +16,7 @@
|
||||
- **灵活配置**: 通过环境变量或 `.env` 文件轻松配置。
|
||||
- **易于部署**: 提供 Docker 一键部署,也支持手动部署。
|
||||
- **健康检查**: 提供健康检查接口,方便监控服务状态。
|
||||
- **图片生成支持**: 支持使用OpenAI的DALL-E模型生成图片
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
@@ -38,8 +39,8 @@
|
||||
1. **克隆项目**:
|
||||
|
||||
```bash
|
||||
git clone <your-repository-url>
|
||||
cd <your-repository-name>
|
||||
git clone https://github.com/snailyp/gemini-balance.git
|
||||
cd gemini-balance
|
||||
```
|
||||
|
||||
2. **安装依赖**:
|
||||
@@ -71,7 +72,7 @@
|
||||
- `TOOLS_CODE_EXECUTION_ENABLED`: 是否启用代码执行工具, 默认为 `false`。
|
||||
- `SHOW_SEARCH_LINK`: 是否显示搜索结果链接(当使用搜索模型时)。
|
||||
- `SHOW_THINKING_PROCESS`: 是否显示模型的"思考"过程(对于某些模型)。
|
||||
- `AUTH_TOKEN`: 备用授权token, 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。
|
||||
- `AUTH_TOKEN`: 主鉴权token(权限较大,注意保管), 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。
|
||||
- `MAX_FAILURES`: 允许单个 API Key 失败的次数,超过此次数后该 Key 将被标记为无效。
|
||||
|
||||
### ▶️ 运行
|
||||
@@ -106,7 +107,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
### 认证
|
||||
|
||||
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个。
|
||||
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个或者 `AUTH_TOKEN`。
|
||||
|
||||
### 获取模型列表
|
||||
|
||||
@@ -175,6 +176,22 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- **Header**: `Authorization: Bearer <your-auth-token>`
|
||||
- **说明**: 只有使用 `AUTH_TOKEN` 才能访问此接口, 用于获取有效和无效的 API Key 列表。
|
||||
|
||||
### 图片生成 (Image Generation)
|
||||
|
||||
- **URL**: `/v1/images/generations`
|
||||
- **Method**: `POST`
|
||||
- **Header**: `Authorization: Bearer <your-auth-token>`
|
||||
- **说明**: Body示例和参数说明
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "汉服美女",
|
||||
"n": 1,
|
||||
"size": "1024x1024"
|
||||
}
|
||||
```
|
||||
|
||||
## 📚 代码结构
|
||||
|
||||
```plaintext
|
||||
@@ -190,16 +207,16 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
│ ├── middleware/ # 中间件
|
||||
│ │ └── request_logging_middleware.py # 请求日志中间件
|
||||
│ ├── schemas/ # 数据模型
|
||||
│ │ ├── gemini_models.py # Gemini 请求/响应模型
|
||||
│ │ └── openai_models.py # OpenAI 请求/响应模型
|
||||
│ │ ├── gemini_models.py # Gemini 原始请求/响应模型
|
||||
│ │ └── openai_models.py # OpenAI 兼容请求/响应模型
|
||||
│ ├── services/ # 服务层
|
||||
│ │ ├── chat/ # 聊天相关服务
|
||||
│ │ │ ├── api_client.py # API 客户端
|
||||
│ │ │ ├── message_converter.py # 消息转换器
|
||||
│ │ │ ├── response_handler.py # 响应处理器
|
||||
│ │ │ └── retry_handler.py #重试处理器
|
||||
│ │ ├── gemini_chat_service.py # Gemini 聊天服务
|
||||
│ │ ├── openai_chat_service.py # OpenAI 聊天服务
|
||||
│ │ ├── gemini_chat_service.py # Gemini 原始聊天服务
|
||||
│ │ ├── openai_chat_service.py # OpenAI 兼容聊天服务
|
||||
│ │ ├── embedding_service.py # 向量服务
|
||||
│ │ ├── key_manager.py # API Key 管理
|
||||
│ │ └── model_service.py # 模型服务
|
||||
|
||||
@@ -4,9 +4,10 @@ from fastapi.responses import StreamingResponse
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
|
||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
||||
from app.services.chat.retry_handler import RetryHandler
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
from app.services.image_create_service import ImageCreateService
|
||||
from app.services.key_manager import KeyManager
|
||||
from app.services.model_service import ModelService
|
||||
from app.services.openai_chat_service import OpenAIChatService
|
||||
@@ -19,6 +20,7 @@ security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
embedding_service = EmbeddingService(settings.BASE_URL)
|
||||
image_create_service = ImageCreateService()
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
@@ -43,16 +45,16 @@ async def chat_completion(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(key_manager.get_next_working_key),
|
||||
):
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
api_key = await key_manager.get_paid_key()
|
||||
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
|
||||
logger.info("-" * 50 + "chat_completion" + "-" * 50)
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
response = await chat_service.create_chat_completion(
|
||||
request=request,
|
||||
api_key=api_key,
|
||||
)
|
||||
response = await chat_service.create_image_chat_completion(request=request)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
@@ -64,6 +66,25 @@ async def chat_completion(
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
):
|
||||
logger.info("-" * 50 + "generate_image" + "-" * 50)
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
|
||||
try:
|
||||
response = image_create_service.generate_images(request)
|
||||
logger.info("Image generation request successful")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Image generation request failed") from e
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
|
||||
@@ -12,6 +12,10 @@ class Settings(BaseSettings):
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = ""
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -129,3 +129,7 @@ def get_request_logger():
|
||||
|
||||
def get_retry_logger():
|
||||
return Logger.setup_logger("retry")
|
||||
|
||||
|
||||
def get_image_create_logger():
|
||||
return Logger.setup_logger("image_create")
|
||||
|
||||
163
app/core/uploader.py
Normal file
163
app/core/uploader.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import requests
|
||||
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
|
||||
class UploadErrorType(Enum):
|
||||
"""上传错误类型枚举"""
|
||||
NETWORK_ERROR = "network_error" # 网络请求错误
|
||||
AUTH_ERROR = "auth_error" # 认证错误
|
||||
INVALID_FILE = "invalid_file" # 无效文件
|
||||
SERVER_ERROR = "server_error" # 服务器错误
|
||||
PARSE_ERROR = "parse_error" # 响应解析错误
|
||||
UNKNOWN = "unknown" # 未知错误
|
||||
|
||||
|
||||
class UploadError(Exception):
|
||||
"""图片上传错误异常类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
|
||||
status_code: Optional[int] = None,
|
||||
details: Optional[dict] = None,
|
||||
original_error: Optional[Exception] = None
|
||||
):
|
||||
"""
|
||||
初始化上传错误异常
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
error_type: 错误类型
|
||||
status_code: HTTP状态码
|
||||
details: 详细错误信息
|
||||
original_error: 原始异常
|
||||
"""
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
self.details = details or {}
|
||||
self.original_error = original_error
|
||||
|
||||
# 构建完整错误信息
|
||||
full_message = f"[{error_type.value}] {message}"
|
||||
if status_code:
|
||||
full_message = f"{full_message} (Status: {status_code})"
|
||||
if details:
|
||||
full_message = f"{full_message} - Details: {details}"
|
||||
|
||||
super().__init__(full_message)
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
|
||||
"""
|
||||
从HTTP响应创建错误实例
|
||||
|
||||
Args:
|
||||
response: HTTP响应对象
|
||||
message: 自定义错误消息
|
||||
"""
|
||||
try:
|
||||
error_data = response.json()
|
||||
details = error_data.get("data", {})
|
||||
return cls(
|
||||
message=message or error_data.get("message", "Unknown error"),
|
||||
error_type=UploadErrorType.SERVER_ERROR,
|
||||
status_code=response.status_code,
|
||||
details=details
|
||||
)
|
||||
except Exception:
|
||||
return cls(
|
||||
message=message or "Failed to parse error response",
|
||||
error_type=UploadErrorType.PARSE_ERROR,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
|
||||
class SmMsUploader(ImageUploader):
|
||||
API_URL = "https://sm.ms/api/v2/upload"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
self.api_key = api_key
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
try:
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"Authorization": f"Basic {self.api_key}"
|
||||
}
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
"smfile": (filename, file, "image/png")
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
response = requests.post(
|
||||
self.API_URL,
|
||||
headers=headers,
|
||||
files=files
|
||||
)
|
||||
|
||||
# 检查响应状态
|
||||
response.raise_for_status()
|
||||
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
# 验证上传是否成功
|
||||
if not result.get("success"):
|
||||
raise UploadError(result.get("message", "Upload failed"))
|
||||
|
||||
# 转换为统一格式
|
||||
data = result["data"]
|
||||
image_metadata = ImageMetadata(
|
||||
width=data["width"],
|
||||
height=data["height"],
|
||||
filename=data["filename"],
|
||||
size=data["size"],
|
||||
url=data["url"],
|
||||
delete_url=data["delete"]
|
||||
)
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
code="success",
|
||||
message="Upload success",
|
||||
data=image_metadata
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
# 处理网络请求相关错误
|
||||
raise UploadError(f"Upload request failed: {str(e)}")
|
||||
except (KeyError, ValueError) as e:
|
||||
# 处理响应解析错误
|
||||
raise UploadError(f"Invalid response format: {str(e)}")
|
||||
except Exception as e:
|
||||
# 处理其他未预期的错误
|
||||
raise UploadError(f"Upload failed: {str(e)}")
|
||||
|
||||
|
||||
class QiniuUploader(ImageUploader):
|
||||
def __init__(self, access_key: str, secret_key: str):
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
# 实现七牛云的具体上传逻辑
|
||||
pass
|
||||
|
||||
|
||||
class ImageUploaderFactory:
|
||||
@staticmethod
|
||||
def create(provider: str, **credentials) -> ImageUploader:
|
||||
if provider == "smms":
|
||||
return SmMsUploader(credentials["api_key"])
|
||||
elif provider == "qiniu":
|
||||
return QiniuUploader(
|
||||
credentials["access_key"],
|
||||
credentials["secret_key"]
|
||||
)
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
23
app/schemas/image_models.py
Normal file
23
app/schemas/image_models.py
Normal file
@@ -0,0 +1,23 @@
|
||||
class ImageMetadata:
|
||||
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: str | None = None):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self.url = url
|
||||
self.delete_url = delete_url
|
||||
|
||||
|
||||
class UploadResponse:
|
||||
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
|
||||
self.success = success
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.data = data
|
||||
|
||||
|
||||
class ImageUploader:
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -18,3 +18,13 @@ class EmbeddingRequest(BaseModel):
|
||||
input: Union[str, List[str]]
|
||||
model: str = "text-embedding-004"
|
||||
encoding_format: Optional[str] = "float"
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
model: str = "DALL-E-3"
|
||||
prompt: str = ""
|
||||
n: int = 1
|
||||
size: Optional[str] = "1024x1024"
|
||||
quality: Optional[str] = ""
|
||||
style: Optional[str] = ""
|
||||
response_format: Optional[str] = "b64_json"
|
||||
|
||||
@@ -84,6 +84,47 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason)
|
||||
return _handle_openai_normal_response(response, model, finish_reason)
|
||||
|
||||
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
|
||||
if stream:
|
||||
return _handle_openai_stream_image_response(image_str,model,finish_reason)
|
||||
return _handle_openai_normal_image_response(image_str,model,finish_reason)
|
||||
|
||||
|
||||
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": image_str} if image_str else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": image_str
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
|
||||
|
||||
81
app/services/image_create_service.py
Normal file
81
app/services/image_create_service.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
import base64
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_image_create_logger
|
||||
from app.core.uploader import ImageUploaderFactory
|
||||
from app.schemas.openai_models import ImageGenerationRequest
|
||||
|
||||
logger = get_image_create_logger()
|
||||
|
||||
|
||||
class ImageCreateService:
|
||||
def __init__(self, aspect_ratio="1:1"):
|
||||
self.image_model = settings.CREATE_IMAGE_MODEL
|
||||
self.paid_key = settings.PAID_KEY
|
||||
self.aspect_ratio = aspect_ratio
|
||||
|
||||
def generate_images(self, request: ImageGenerationRequest):
|
||||
client = genai.Client(api_key=self.paid_key)
|
||||
if request.size == "1024x1024":
|
||||
self.aspect_ratio = "1:1"
|
||||
elif request.size == "1792x1024":
|
||||
self.aspect_ratio = "16:9"
|
||||
elif request.size == "1027x1792":
|
||||
self.aspect_ratio = "9:16"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
|
||||
)
|
||||
|
||||
response = client.models.generate_images(
|
||||
model=self.image_model,
|
||||
prompt=request.prompt,
|
||||
config=types.GenerateImagesConfig(
|
||||
number_of_images=request.n,
|
||||
output_mime_type="image/png",
|
||||
aspect_ratio=self.aspect_ratio,
|
||||
safety_filter_level="BLOCK_LOW_AND_ABOVE",
|
||||
person_generation="ALLOW_ADULT",
|
||||
# language="auto"
|
||||
),
|
||||
)
|
||||
|
||||
if response.generated_images:
|
||||
images_data = []
|
||||
for index, generated_image in enumerate(response.generated_images):
|
||||
image_data = generated_image.image.image_bytes
|
||||
image_uploader = None
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
upload_response = image_uploader.upload(image_data,filename)
|
||||
|
||||
# base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||
images_data.append({
|
||||
"url": f"{upload_response.data.url}",
|
||||
"revised_prompt": request.prompt
|
||||
})
|
||||
|
||||
response_data = {
|
||||
"created": int(time.time()), # Current timestamp
|
||||
"data": images_data
|
||||
}
|
||||
return response_data
|
||||
else:
|
||||
raise Exception("I can't generate these images")
|
||||
|
||||
def generate_images_chat(self, request: ImageGenerationRequest) -> str:
|
||||
response = self.generate_images(request)
|
||||
image_datas = response["data"]
|
||||
if image_datas:
|
||||
markdown_images = []
|
||||
for index, image_data in enumerate(image_datas):
|
||||
markdown_images.append(f"")
|
||||
return "\n".join(markdown_images)
|
||||
|
||||
@@ -15,7 +15,11 @@ class KeyManager:
|
||||
self.failure_count_lock = asyncio.Lock()
|
||||
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
||||
self.MAX_FAILURES = settings.MAX_FAILURES
|
||||
self.paid_key = settings.PAID_KEY
|
||||
|
||||
async def get_paid_key(self) -> str:
|
||||
return self.paid_key
|
||||
|
||||
async def get_next_key(self) -> str:
|
||||
"""获取下一个API key"""
|
||||
async with self.key_cycle_lock:
|
||||
|
||||
@@ -2,10 +2,10 @@ import requests
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from app.core.logger import get_model_logger
|
||||
from app.core.config import settings
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
|
||||
class ModelService:
|
||||
def __init__(self, model_search: list):
|
||||
self.model_search = model_search
|
||||
@@ -52,6 +52,11 @@ class ModelService:
|
||||
"parent": None,
|
||||
}
|
||||
openai_format["data"].append(openai_model)
|
||||
|
||||
if settings.CREATE_IMAGE_MODEL:
|
||||
image_model = openai_model.copy()
|
||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
openai_format["data"].append(image_model)
|
||||
|
||||
if model_id in self.model_search:
|
||||
search_model = openai_model.copy()
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List, Union
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
||||
from app.services.chat.api_client import GeminiApiClient
|
||||
from app.schemas.openai_models import ChatRequest
|
||||
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.core.config import settings
|
||||
from app.services.image_create_service import ImageCreateService
|
||||
from app.services.key_manager import KeyManager
|
||||
|
||||
logger = get_openai_logger()
|
||||
@@ -31,9 +31,9 @@ def _build_tools(
|
||||
model = request.model
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(messages)
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(messages)
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
@@ -86,16 +86,17 @@ def _build_payload(
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager):
|
||||
self.message_converter = OpenAIMessageConverter()
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
|
||||
self.response_handler = OpenAIResponseHandler(config=None)
|
||||
self.api_client = GeminiApiClient(base_url)
|
||||
self.key_manager = key_manager
|
||||
self.image_create_service = ImageCreateService()
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
@@ -109,7 +110,7 @@ class OpenAIChatService:
|
||||
return self._handle_normal_completion(request.model, payload, api_key)
|
||||
|
||||
def _handle_normal_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
response = self.api_client.generate_content(payload, model, api_key)
|
||||
@@ -118,7 +119,7 @@ class OpenAIChatService:
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
@@ -126,7 +127,7 @@ class OpenAIChatService:
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
payload, model, api_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
@@ -154,3 +155,38 @@ class OpenAIChatService:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
image_generate_request.prompt = request.messages[-1]["content"]
|
||||
image_res = self.image_create_service.generate_images_chat(image_generate_request)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(request.model,image_res)
|
||||
else:
|
||||
return self._handle_normal_image_completion(request.model,image_res)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if image_data:
|
||||
openai_chunk = self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=True, finish_reason=None
|
||||
)
|
||||
if openai_chunk:
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Image chat streaming completed successfully")
|
||||
|
||||
def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
return self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
@@ -5,4 +5,5 @@ pydantic
|
||||
pydantic_settings
|
||||
requests
|
||||
starlette
|
||||
uvicorn
|
||||
uvicorn
|
||||
google-genai
|
||||
Reference in New Issue
Block a user