feat: 添加图片生成功能及相关配置

- 添加图片生成相关配置和环境变量
- 新增图片上传服务和模型定义
- 扩展模型服务以支持图片生成模型
- 添加图片生成响应处理器
- 更新README文档以反映新功能
- 添加GitHub Actions发布工作流
This commit is contained in:
yinpeng
2025-02-11 01:59:16 +08:00
parent b3842b2329
commit a74ac03836
16 changed files with 497 additions and 28 deletions

15
.env.example Normal file
View File

@@ -0,0 +1,15 @@
API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"]
ALLOWED_TOKENS=["sk-123456"]
# AUTH_TOKEN=sk-123456
MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
TOOLS_CODE_EXECUTION_ENABLED=true
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
BASE_URL=https://generativelanguage.googleapis.com/v1beta
MAX_FAILURES=10
#########################image_generate 相关配置###########################
PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx
CREATE_IMAGE_MODEL=imagen-3.0-generate-002
UPLOAD_PROVIDER=smms
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
##########################################################################

41
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: Publish Release
on:
push:
tags:
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0
jobs:
release:
runs-on: ubuntu-latest
steps:
# Step 1: 检出代码库
- name: Checkout code
uses: actions/checkout@v3
# Step 2: 自动生成 Release
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref_name }}
release_name: ${{ github.ref_name }}
body: |
## Release Notes
- 自动发布版本。
- 请根据需求更新对应内容。
draft: false
prerelease: false
# Step 3: 可选,上传构建文件
- name: Upload Release Asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./your-build-file.zip # 替换为你的构建文件路径
asset_name: your-build-file.zip # 替换为你的文件名
asset_content_type: application/zip

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"commentTranslate.source": "upupnoah.chatgpt-comment-translateX-chatgpt"
}

View File

@@ -16,6 +16,7 @@
- **灵活配置**: 通过环境变量或 `.env` 文件轻松配置。
- **易于部署**: 提供 Docker 一键部署,也支持手动部署。
- **健康检查**: 提供健康检查接口,方便监控服务状态。
- **图片生成支持**: 支持使用OpenAI的DALL-E模型生成图片
## 🛠️ 技术栈
@@ -38,8 +39,8 @@
1. **克隆项目**:
```bash
git clone <your-repository-url>
cd <your-repository-name>
git clone https://github.com/snailyp/gemini-balance.git
cd gemini-balance
```
2. **安装依赖**:
@@ -71,7 +72,7 @@
- `TOOLS_CODE_EXECUTION_ENABLED`: 是否启用代码执行工具, 默认为 `false`。
- `SHOW_SEARCH_LINK`: 是否显示搜索结果链接(当使用搜索模型时)。
- `SHOW_THINKING_PROCESS`: 是否显示模型的"思考"过程(对于某些模型)。
- `AUTH_TOKEN`: 备用授权token, 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。
- `AUTH_TOKEN`: 主鉴权token(权限较大,注意保管), 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。
- `MAX_FAILURES`: 允许单个 API Key 失败的次数,超过此次数后该 Key 将被标记为无效。
### ▶️ 运行
@@ -106,7 +107,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
### 认证
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个。
所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer <your-token>`,其中 `<your-token>` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个或者 `AUTH_TOKEN`
### 获取模型列表
@@ -175,6 +176,22 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
- **Header**: `Authorization: Bearer <your-auth-token>`
- **说明**: 只有使用 `AUTH_TOKEN` 才能访问此接口, 用于获取有效和无效的 API Key 列表。
### 图片生成 (Image Generation)
- **URL**: `/v1/images/generations`
- **Method**: `POST`
- **Header**: `Authorization: Bearer <your-auth-token>`
- **说明**: Body示例和参数说明
```json
{
"model": "dall-e-3",
"prompt": "汉服美女",
"n": 1,
"size": "1024x1024"
}
```
## 📚 代码结构
```plaintext
@@ -190,16 +207,16 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
│ ├── middleware/ # 中间件
│ │ └── request_logging_middleware.py # 请求日志中间件
│ ├── schemas/ # 数据模型
│ │ ├── gemini_models.py # Gemini 请求/响应模型
│ │ └── openai_models.py # OpenAI 请求/响应模型
│ │ ├── gemini_models.py # Gemini 原始请求/响应模型
│ │ └── openai_models.py # OpenAI 兼容请求/响应模型
│ ├── services/ # 服务层
│ │ ├── chat/ # 聊天相关服务
│ │ │ ├── api_client.py # API 客户端
│ │ │ ├── message_converter.py # 消息转换器
│ │ │ ├── response_handler.py # 响应处理器
│ │ │ └── retry_handler.py #重试处理器
│ │ ├── gemini_chat_service.py # Gemini 聊天服务
│ │ ├── openai_chat_service.py # OpenAI 聊天服务
│ │ ├── gemini_chat_service.py # Gemini 原始聊天服务
│ │ ├── openai_chat_service.py # OpenAI 兼容聊天服务
│ │ ├── embedding_service.py # 向量服务
│ │ ├── key_manager.py # API Key 管理
│ │ └── model_service.py # 模型服务

View File

@@ -4,9 +4,10 @@ from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_openai_logger
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
@@ -19,6 +20,7 @@ security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
@router.get("/v1/models")
@@ -43,16 +45,16 @@ async def chat_completion(
_=Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
):
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
api_key = await key_manager.get_paid_key()
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "chat_completion" + "-" * 50)
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
try:
response = await chat_service.create_chat_completion(
request=request,
api_key=api_key,
)
response = await chat_service.create_image_chat_completion(request=request)
# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
@@ -64,6 +66,25 @@ async def chat_completion(
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
request: ImageGenerationRequest,
_=Depends(security_service.verify_authorization),
):
logger.info("-" * 50 + "generate_image" + "-" * 50)
logger.info(f"Handling image generation request for prompt: {request.prompt}")
try:
response = image_create_service.generate_images(request)
logger.info("Image generation request successful")
return response
except Exception as e:
logger.error(f"Image generation request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Image generation request failed") from e
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(

View File

@@ -12,6 +12,10 @@ class Settings(BaseSettings):
SHOW_THINKING_PROCESS: bool = True
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = ""
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
def __init__(self):
super().__init__()

View File

@@ -129,3 +129,7 @@ def get_request_logger():
def get_retry_logger():
return Logger.setup_logger("retry")
def get_image_create_logger():
return Logger.setup_logger("image_create")

163
app/core/uploader.py Normal file
View File

@@ -0,0 +1,163 @@
import requests
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any
class UploadErrorType(Enum):
"""上传错误类型枚举"""
NETWORK_ERROR = "network_error" # 网络请求错误
AUTH_ERROR = "auth_error" # 认证错误
INVALID_FILE = "invalid_file" # 无效文件
SERVER_ERROR = "server_error" # 服务器错误
PARSE_ERROR = "parse_error" # 响应解析错误
UNKNOWN = "unknown" # 未知错误
class UploadError(Exception):
"""图片上传错误异常类"""
def __init__(
self,
message: str,
error_type: UploadErrorType = UploadErrorType.UNKNOWN,
status_code: Optional[int] = None,
details: Optional[dict] = None,
original_error: Optional[Exception] = None
):
"""
初始化上传错误异常
Args:
message: 错误消息
error_type: 错误类型
status_code: HTTP状态码
details: 详细错误信息
original_error: 原始异常
"""
self.message = message
self.error_type = error_type
self.status_code = status_code
self.details = details or {}
self.original_error = original_error
# 构建完整错误信息
full_message = f"[{error_type.value}] {message}"
if status_code:
full_message = f"{full_message} (Status: {status_code})"
if details:
full_message = f"{full_message} - Details: {details}"
super().__init__(full_message)
@classmethod
def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
"""
从HTTP响应创建错误实例
Args:
response: HTTP响应对象
message: 自定义错误消息
"""
try:
error_data = response.json()
details = error_data.get("data", {})
return cls(
message=message or error_data.get("message", "Unknown error"),
error_type=UploadErrorType.SERVER_ERROR,
status_code=response.status_code,
details=details
)
except Exception:
return cls(
message=message or "Failed to parse error response",
error_type=UploadErrorType.PARSE_ERROR,
status_code=response.status_code
)
class SmMsUploader(ImageUploader):
API_URL = "https://sm.ms/api/v2/upload"
def __init__(self, api_key: str):
self.api_key = api_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
try:
# 准备请求头
headers = {
"Authorization": f"Basic {self.api_key}"
}
# 准备文件数据
files = {
"smfile": (filename, file, "image/png")
}
# 发送请求
response = requests.post(
self.API_URL,
headers=headers,
files=files
)
# 检查响应状态
response.raise_for_status()
# 解析响应
result = response.json()
# 验证上传是否成功
if not result.get("success"):
raise UploadError(result.get("message", "Upload failed"))
# 转换为统一格式
data = result["data"]
image_metadata = ImageMetadata(
width=data["width"],
height=data["height"],
filename=data["filename"],
size=data["size"],
url=data["url"],
delete_url=data["delete"]
)
return UploadResponse(
success=True,
code="success",
message="Upload success",
data=image_metadata
)
except requests.RequestException as e:
# 处理网络请求相关错误
raise UploadError(f"Upload request failed: {str(e)}")
except (KeyError, ValueError) as e:
# 处理响应解析错误
raise UploadError(f"Invalid response format: {str(e)}")
except Exception as e:
# 处理其他未预期的错误
raise UploadError(f"Upload failed: {str(e)}")
class QiniuUploader(ImageUploader):
def __init__(self, access_key: str, secret_key: str):
self.access_key = access_key
self.secret_key = secret_key
def upload(self, file: bytes, filename: str) -> UploadResponse:
# 实现七牛云的具体上传逻辑
pass
class ImageUploaderFactory:
@staticmethod
def create(provider: str, **credentials) -> ImageUploader:
if provider == "smms":
return SmMsUploader(credentials["api_key"])
elif provider == "qiniu":
return QiniuUploader(
credentials["access_key"],
credentials["secret_key"]
)
raise ValueError(f"Unknown provider: {provider}")

View File

@@ -0,0 +1,23 @@
class ImageMetadata:
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: str | None = None):
self.width = width
self.height = height
self.filename = filename
self.size = size
self.url = url
self.delete_url = delete_url
class UploadResponse:
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
self.success = success
self.code = code
self.message = message
self.data = data
class ImageUploader:
def upload(self, file: bytes, filename: str) -> UploadResponse:
raise NotImplementedError

View File

@@ -18,3 +18,13 @@ class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-004"
encoding_format: Optional[str] = "float"
class ImageGenerationRequest(BaseModel):
model: str = "DALL-E-3"
prompt: str = ""
n: int = 1
size: Optional[str] = "1024x1024"
quality: Optional[str] = ""
style: Optional[str] = ""
response_format: Optional[str] = "b64_json"

View File

@@ -84,6 +84,47 @@ class OpenAIResponseHandler(ResponseHandler):
if stream:
return _handle_openai_stream_response(response, model, finish_reason)
return _handle_openai_normal_response(response, model, finish_reason)
def handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"):
if stream:
return _handle_openai_stream_image_response(image_str,model,finish_reason)
return _handle_openai_normal_image_response(image_str,model,finish_reason)
def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": image_str} if image_str else {},
"finish_reason": finish_reason
}]
}
def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": image_str
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:

View File

@@ -0,0 +1,81 @@
import time
import uuid
from google import genai
from google.genai import types
import base64
from app.core.config import settings
from app.core.logger import get_image_create_logger
from app.core.uploader import ImageUploaderFactory
from app.schemas.openai_models import ImageGenerationRequest
logger = get_image_create_logger()
class ImageCreateService:
def __init__(self, aspect_ratio="1:1"):
self.image_model = settings.CREATE_IMAGE_MODEL
self.paid_key = settings.PAID_KEY
self.aspect_ratio = aspect_ratio
def generate_images(self, request: ImageGenerationRequest):
client = genai.Client(api_key=self.paid_key)
if request.size == "1024x1024":
self.aspect_ratio = "1:1"
elif request.size == "1792x1024":
self.aspect_ratio = "16:9"
elif request.size == "1027x1792":
self.aspect_ratio = "9:16"
else:
raise ValueError(
f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
)
response = client.models.generate_images(
model=self.image_model,
prompt=request.prompt,
config=types.GenerateImagesConfig(
number_of_images=request.n,
output_mime_type="image/png",
aspect_ratio=self.aspect_ratio,
safety_filter_level="BLOCK_LOW_AND_ABOVE",
person_generation="ALLOW_ADULT",
# language="auto"
),
)
if response.generated_images:
images_data = []
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
upload_response = image_uploader.upload(image_data,filename)
# base64_image = base64.b64encode(image_data).decode('utf-8')
images_data.append({
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt
})
response_data = {
"created": int(time.time()), # Current timestamp
"data": images_data
}
return response_data
else:
raise Exception("I can't generate these images")
def generate_images_chat(self, request: ImageGenerationRequest) -> str:
response = self.generate_images(request)
image_datas = response["data"]
if image_datas:
markdown_images = []
for index, image_data in enumerate(image_datas):
markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})")
return "\n".join(markdown_images)

View File

@@ -15,7 +15,11 @@ class KeyManager:
self.failure_count_lock = asyncio.Lock()
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
self.MAX_FAILURES = settings.MAX_FAILURES
self.paid_key = settings.PAID_KEY
async def get_paid_key(self) -> str:
return self.paid_key
async def get_next_key(self) -> str:
"""获取下一个API key"""
async with self.key_cycle_lock:

View File

@@ -2,10 +2,10 @@ import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from app.core.logger import get_model_logger
from app.core.config import settings
logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list):
self.model_search = model_search
@@ -52,6 +52,11 @@ class ModelService:
"parent": None,
}
openai_format["data"].append(openai_model)
if settings.CREATE_IMAGE_MODEL:
image_model = openai_model.copy()
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
openai_format["data"].append(image_model)
if model_id in self.model_search:
search_model = openai_model.copy()

View File

@@ -3,11 +3,11 @@
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.schemas.openai_models import ChatRequest
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
from app.core.config import settings
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
logger = get_openai_logger()
@@ -31,9 +31,9 @@ def _build_tools(
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(messages)
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
@@ -86,16 +86,17 @@ def _build_payload(
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.message_converter = OpenAIMessageConverter()
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
@@ -109,7 +110,7 @@ class OpenAIChatService:
return self._handle_normal_completion(request.model, payload, api_key)
def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
@@ -118,7 +119,7 @@ class OpenAIChatService:
)
async def _handle_stream_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
@@ -126,7 +127,7 @@ class OpenAIChatService:
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(
payload, model, api_key
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
@@ -154,3 +155,38 @@ class OpenAIChatService:
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break
async def create_image_chat_completion(
self,
request: ChatRequest,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
image_generate_request.prompt = request.messages[-1]["content"]
image_res = self.image_create_service.generate_images_chat(image_generate_request)
if request.stream:
return self._handle_stream_image_completion(request.model,image_res)
else:
return self._handle_normal_image_completion(request.model,image_res)
async def _handle_stream_image_completion(
self, model: str, image_data: str
) -> AsyncGenerator[str, None]:
if image_data:
openai_chunk = self.response_handler.handle_image_chat_response(
image_data, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
def _handle_normal_image_completion(
self, model: str, image_data: str
) -> Dict[str, Any]:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)

View File

@@ -5,4 +5,5 @@ pydantic
pydantic_settings
requests
starlette
uvicorn
uvicorn
google-genai