diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..3660838 --- /dev/null +++ b/.env.example @@ -0,0 +1,15 @@ +API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"] +ALLOWED_TOKENS=["sk-123456"] +# AUTH_TOKEN=sk-123456 +MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"] +TOOLS_CODE_EXECUTION_ENABLED=true +SHOW_SEARCH_LINK=true +SHOW_THINKING_PROCESS=true +BASE_URL=https://generativelanguage.googleapis.com/v1beta +MAX_FAILURES=10 +#########################image_generate 相关配置########################### +PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx +CREATE_IMAGE_MODEL=imagen-3.0-generate-002 +UPLOAD_PROVIDER=smms +SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX +########################################################################## diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..899b574 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,41 @@ +name: Publish Release + +on: + push: + tags: + - 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0) + +jobs: + release: + runs-on: ubuntu-latest + steps: + # Step 1: 检出代码库 + - name: Checkout code + uses: actions/checkout@v3 + + # Step 2: 自动生成 Release + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref_name }} + release_name: ${{ github.ref_name }} + body: | + ## Release Notes + - 自动发布版本。 + - 请根据需求更新对应内容。 + draft: false + prerelease: false + + # Step 3: 可选,上传构建文件 + - name: Upload Release Asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: ./your-build-file.zip # 替换为你的构建文件路径 + asset_name: your-build-file.zip # 替换为你的文件名 + asset_content_type: application/zip \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..6314584 --- 
/dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "commentTranslate.source": "upupnoah.chatgpt-comment-translateX-chatgpt" +} \ No newline at end of file diff --git a/README.md b/README.md index 75f5c8a..c1e5fdf 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ - **灵活配置**: 通过环境变量或 `.env` 文件轻松配置。 - **易于部署**: 提供 Docker 一键部署,也支持手动部署。 - **健康检查**: 提供健康检查接口,方便监控服务状态。 +- **图片生成支持**: 支持使用OpenAI的DALL-E模型生成图片 ## 🛠️ 技术栈 @@ -38,8 +39,8 @@ 1. **克隆项目**: ```bash - git clone - cd + git clone https://github.com/snailyp/gemini-balance.git + cd gemini-balance ``` 2. **安装依赖**: @@ -71,7 +72,7 @@ - `TOOLS_CODE_EXECUTION_ENABLED`: 是否启用代码执行工具, 默认为 `false`。 - `SHOW_SEARCH_LINK`: 是否显示搜索结果链接(当使用搜索模型时)。 - `SHOW_THINKING_PROCESS`: 是否显示模型的"思考"过程(对于某些模型)。 - - `AUTH_TOKEN`: 备用授权token, 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。 + - `AUTH_TOKEN`: 主鉴权token(权限较大,注意保管), 如果不设置, 默认为 `ALLOWED_TOKENS` 的第一个。 - `MAX_FAILURES`: 允许单个 API Key 失败的次数,超过此次数后该 Key 将被标记为无效。 ### ▶️ 运行 @@ -106,7 +107,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload ### 认证 -所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer `,其中 `` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个。 +所有 API 请求都需要在 Header 中添加 `Authorization` 字段,值为 `Bearer `,其中 `` 需要替换为你在 `.env` 文件中配置的 `ALLOWED_TOKENS` 中的一个或者 `AUTH_TOKEN`。 ### 获取模型列表 @@ -175,6 +176,22 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload - **Header**: `Authorization: Bearer ` - **说明**: 只有使用 `AUTH_TOKEN` 才能访问此接口, 用于获取有效和无效的 API Key 列表。 +### 图片生成 (Image Generation) + +- **URL**: `/v1/images/generations` +- **Method**: `POST` +- **Header**: `Authorization: Bearer ` +- **说明**: Body示例和参数说明 + + ```json + { + "model": "dall-e-3", + "prompt": "汉服美女", + "n": 1, + "size": "1024x1024" + } + ``` + ## 📚 代码结构 ```plaintext @@ -190,16 +207,16 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload │ ├── middleware/ # 中间件 │ │ └── request_logging_middleware.py # 请求日志中间件 │ ├── schemas/ # 数据模型 -│ │ ├── gemini_models.py # Gemini 请求/响应模型 -│ │ └── openai_models.py # OpenAI 请求/响应模型 
# NOTE(review): the decorators below are not visible in this hunk — confirm
# they match the original registration (and any RetryHandler wrapping).
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(key_manager.get_next_working_key),
):
    """OpenAI-compatible chat completion endpoint.

    Requests for the image pseudo-model (``<CREATE_IMAGE_MODEL>-chat``) are
    routed to the image-generation pipeline with the paid key; every other
    model goes through the normal Gemini chat pipeline.  (The original patch
    sent *all* requests down the image path, which broke plain chat models.)
    """
    # bool() guard: with CREATE_IMAGE_MODEL unset ("") a model literally named
    # "-chat" must NOT be treated as an image request.
    is_image_chat = bool(settings.CREATE_IMAGE_MODEL) and (
        request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    )
    if is_image_chat:
        # Imagen generation is only available to the paid key, not the
        # rotating free-key pool.
        api_key = await key_manager.get_paid_key()

    chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
    logger.info("-" * 50 + "chat_completion" + "-" * 50)
    logger.info(f"Handling chat completion request for model: {request.model}")
    logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
    logger.info(f"Using API key: {api_key}")
    try:
        if is_image_chat:
            response = await chat_service.create_image_chat_completion(request=request)
        else:
            response = await chat_service.create_chat_completion(
                request=request,
                api_key=api_key,
            )
        # Streamed completions are served as server-sent events.
        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
    except Exception as e:
        logger.error(f"Chat completion failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Chat completion failed") from e


@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """OpenAI-compatible image generation endpoint (backed by Imagen).

    NOTE(review): ``generate_images`` performs blocking HTTP I/O inside this
    async handler and will stall the event loop under load — consider
    ``await asyncio.to_thread(image_create_service.generate_images, request)``.
    """
    logger.info("-" * 50 + "generate_image" + "-" * 50)
    logger.info(f"Handling image generation request for prompt: {request.prompt}")

    try:
        response = image_create_service.generate_images(request)
        logger.info("Image generation request successful")
        return response

    except Exception as e:
        logger.error(f"Image generation request failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Image generation request failed") from e
from enum import Enum
from typing import Any, Optional

import requests

from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse

# Per-request timeout (seconds) for the upload HTTP call.
_REQUEST_TIMEOUT = 30


class UploadErrorType(Enum):
    """Categories of image-upload failures."""

    NETWORK_ERROR = "network_error"  # transport-level request failure
    AUTH_ERROR = "auth_error"        # authentication/credential failure
    INVALID_FILE = "invalid_file"    # the file itself was rejected
    SERVER_ERROR = "server_error"    # provider returned an error payload
    PARSE_ERROR = "parse_error"      # provider response could not be parsed
    UNKNOWN = "unknown"              # anything else


class UploadError(Exception):
    """Raised when an image upload fails.

    Carries a machine-readable ``error_type`` plus optional HTTP status code,
    provider details and the original exception for diagnostics.
    """

    def __init__(
        self,
        message: str,
        error_type: UploadErrorType = UploadErrorType.UNKNOWN,
        status_code: Optional[int] = None,
        details: Optional[dict] = None,
        original_error: Optional[Exception] = None,
    ):
        self.message = message
        self.error_type = error_type
        self.status_code = status_code
        self.details = details or {}
        self.original_error = original_error

        # Human-readable summary: "[type] message (Status: n) - Details: {...}"
        full_message = f"[{error_type.value}] {message}"
        if status_code:
            full_message = f"{full_message} (Status: {status_code})"
        if details:
            full_message = f"{full_message} - Details: {details}"

        super().__init__(full_message)

    @classmethod
    def from_response(cls, response: Any, message: Optional[str] = None) -> "UploadError":
        """Build an UploadError from a failed HTTP response object.

        Falls back to PARSE_ERROR when the body is not the expected JSON.
        """
        try:
            error_data = response.json()
            return cls(
                message=message or error_data.get("message", "Unknown error"),
                error_type=UploadErrorType.SERVER_ERROR,
                status_code=response.status_code,
                details=error_data.get("data", {}),
            )
        except Exception:
            return cls(
                message=message or "Failed to parse error response",
                error_type=UploadErrorType.PARSE_ERROR,
                status_code=response.status_code,
            )


class SmMsUploader(ImageUploader):
    """Uploads images to https://sm.ms and normalizes the response."""

    API_URL = "https://sm.ms/api/v2/upload"

    def __init__(self, api_key: str):
        self.api_key = api_key

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        """Upload PNG bytes to SM.MS.

        Returns a normalized UploadResponse; raises UploadError (with a
        specific error_type) on any failure.
        """
        try:
            response = requests.post(
                self.API_URL,
                headers={"Authorization": f"Basic {self.api_key}"},
                files={"smfile": (filename, file, "image/png")},
                # Original had no timeout and could hang forever on a stuck socket.
                timeout=_REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            result = response.json()

            if not result.get("success"):
                raise UploadError(
                    result.get("message", "Upload failed"),
                    error_type=UploadErrorType.SERVER_ERROR,
                    status_code=response.status_code,
                )

            data = result["data"]
            metadata = ImageMetadata(
                width=data["width"],
                height=data["height"],
                filename=data["filename"],
                size=data["size"],
                url=data["url"],
                delete_url=data["delete"],
            )
            return UploadResponse(
                success=True,
                code="success",
                message="Upload success",
                data=metadata,
            )

        except UploadError:
            # Don't re-wrap our own error (the original clobbered it with UNKNOWN).
            raise
        except requests.RequestException as e:
            raise UploadError(
                f"Upload request failed: {str(e)}",
                error_type=UploadErrorType.NETWORK_ERROR,
                original_error=e,
            ) from e
        except (KeyError, ValueError) as e:
            raise UploadError(
                f"Invalid response format: {str(e)}",
                error_type=UploadErrorType.PARSE_ERROR,
                original_error=e,
            ) from e
        except Exception as e:
            raise UploadError(
                f"Upload failed: {str(e)}",
                original_error=e,
            ) from e


class QiniuUploader(ImageUploader):
    """Placeholder Qiniu uploader — not implemented yet."""

    def __init__(self, access_key: str, secret_key: str):
        self.access_key = access_key
        self.secret_key = secret_key

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        # TODO: implement the actual Qiniu upload.
        pass


class ImageUploaderFactory:
    """Creates the configured ImageUploader implementation."""

    @staticmethod
    def create(provider: str, **credentials) -> ImageUploader:
        if provider == "smms":
            return SmMsUploader(credentials["api_key"])
        if provider == "qiniu":
            return QiniuUploader(
                credentials["access_key"],
                credentials["secret_key"],
            )
        raise ValueError(f"Unknown provider: {provider}")
handle_image_chat_response(self, image_str: str, model: str, stream=False, finish_reason="stop"): + if stream: + return _handle_openai_stream_image_response(image_str,model,finish_reason) + return _handle_openai_normal_image_response(image_str,model,finish_reason) + + +def _handle_openai_stream_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": image_str} if image_str else {}, + "finish_reason": finish_reason + }] + } + + +def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason: str) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": image_str + }, + "finish_reason": finish_reason + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str: diff --git a/app/services/image_create_service.py b/app/services/image_create_service.py new file mode 100644 index 0000000..5f9b086 --- /dev/null +++ b/app/services/image_create_service.py @@ -0,0 +1,81 @@ +import time +import uuid + +from google import genai +from google.genai import types +import base64 + +from app.core.config import settings +from app.core.logger import get_image_create_logger +from app.core.uploader import ImageUploaderFactory +from app.schemas.openai_models import ImageGenerationRequest + +logger = get_image_create_logger() + + +class ImageCreateService: + def __init__(self, aspect_ratio="1:1"): + self.image_model = settings.CREATE_IMAGE_MODEL + self.paid_key = settings.PAID_KEY + self.aspect_ratio = aspect_ratio + + def generate_images(self, request: ImageGenerationRequest): + 
class ImageCreateService:
    """Generates images via Google Imagen and uploads them to an image host."""

    # OpenAI-style size -> Imagen aspect ratio.  The original compared the
    # portrait size against "1027x1792" (typo), so the documented 1024x1792
    # always raised ValueError.
    _SIZE_TO_ASPECT_RATIO = {
        "1024x1024": "1:1",
        "1792x1024": "16:9",
        "1024x1792": "9:16",
    }

    def __init__(self, aspect_ratio: str = "1:1"):
        self.image_model = settings.CREATE_IMAGE_MODEL  # Imagen model id
        self.paid_key = settings.PAID_KEY               # paid key required for Imagen
        self.aspect_ratio = aspect_ratio

    def generate_images(self, request: ImageGenerationRequest) -> dict:
        """Generate ``request.n`` images, upload each one, and return an
        OpenAI images-API-shaped dict: ``{"created": ts, "data": [...]}``.

        Raises ValueError for an unsupported size or upload provider, and a
        generic Exception when the model returns no images.
        """
        try:
            self.aspect_ratio = self._SIZE_TO_ASPECT_RATIO[request.size]
        except KeyError:
            raise ValueError(
                f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
            ) from None

        client = genai.Client(api_key=self.paid_key)
        response = client.models.generate_images(
            model=self.image_model,
            prompt=request.prompt,
            config=types.GenerateImagesConfig(
                number_of_images=request.n,
                output_mime_type="image/png",
                aspect_ratio=self.aspect_ratio,
                safety_filter_level="BLOCK_LOW_AND_ABOVE",
                person_generation="ALLOW_ADULT",
            ),
        )

        if not response.generated_images:
            raise Exception("I can't generate these images")

        if settings.UPLOAD_PROVIDER != "smms":
            # Only SM.MS is wired up; fail loudly instead of the original's
            # AttributeError on a never-assigned uploader.
            raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")

        images_data = []
        for generated_image in response.generated_images:
            image_bytes = generated_image.image.image_bytes
            uploader = ImageUploaderFactory.create(
                provider=settings.UPLOAD_PROVIDER,
                api_key=settings.SMMS_SECRET_TOKEN,
            )
            # Date-prefixed random name keeps the host's gallery organized.
            filename = f"{time.strftime('%Y/%m/%d')}/{uuid.uuid4().hex[:8]}.png"
            upload_response = uploader.upload(image_bytes, filename)
            images_data.append(
                {
                    "url": f"{upload_response.data.url}",
                    "revised_prompt": request.prompt,
                }
            )

        return {
            "created": int(time.time()),
            "data": images_data,
        }

    def generate_images_chat(self, request: ImageGenerationRequest) -> str:
        """Render generated image URLs as markdown, one image per line.

        Always returns a string (the original returned None when the data
        list was empty).
        """
        image_datas = self.generate_images(request)["data"]
        return "\n".join(
            f"![Generated Image {index + 1}]({item['url']})"
            for index, item in enumerate(image_datas)
        )
    async def get_paid_key(self) -> str:
        """Return the dedicated paid API key (the ``PAID_KEY`` setting).

        Declared ``async`` for interface symmetry with the other ``get_*_key``
        accessors on KeyManager; no lock is taken because the paid key is
        fixed at construction and never rotates.
        """
        return self.paid_key
self.api_client.stream_generate_content( - payload, model, api_key + payload, model, api_key ): # print(line) if line.startswith("data:"): @@ -154,3 +155,38 @@ class OpenAIChatService: yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n" yield "data: [DONE]\n\n" break + + async def create_image_chat_completion( + self, + request: ChatRequest, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + + image_generate_request = ImageGenerationRequest() + image_generate_request.prompt = request.messages[-1]["content"] + image_res = self.image_create_service.generate_images_chat(image_generate_request) + + if request.stream: + return self._handle_stream_image_completion(request.model,image_res) + else: + return self._handle_normal_image_completion(request.model,image_res) + + async def _handle_stream_image_completion( + self, model: str, image_data: str + ) -> AsyncGenerator[str, None]: + if image_data: + openai_chunk = self.response_handler.handle_image_chat_response( + image_data, model, stream=True, finish_reason=None + ) + if openai_chunk: + yield f"data: {json.dumps(openai_chunk)}\n\n" + yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n" + yield "data: [DONE]\n\n" + logger.info("Image chat streaming completed successfully") + + def _handle_normal_image_completion( + self, model: str, image_data: str + ) -> Dict[str, Any]: + + return self.response_handler.handle_image_chat_response( + image_data, model, stream=False, finish_reason="stop" + ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 347af8b..fd46e65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ pydantic pydantic_settings requests starlette -uvicorn \ No newline at end of file +uvicorn +google-genai \ No newline at end of file