diff --git a/.env.example b/.env.example index 2d16c1c..863a137 100644 --- a/.env.example +++ b/.env.example @@ -2,7 +2,8 @@ API_KEYS=["AIzaSyxxxxxxxxxxxxxxxxxxx","AIzaSyxxxxxxxxxxxxxxxxxxx"] ALLOWED_TOKENS=["sk-123456"] # AUTH_TOKEN=sk-123456 MODEL_SEARCH=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"] -TOOLS_CODE_EXECUTION_ENABLED=true +MODEL_IMAGE=["gemini-2.0-flash-exp"] +TOOLS_CODE_EXECUTION_ENABLED=false SHOW_SEARCH_LINK=true SHOW_THINKING_PROCESS=true BASE_URL=https://generativelanguage.googleapis.com/v1beta @@ -12,6 +13,7 @@ PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx CREATE_IMAGE_MODEL=imagen-3.0-generate-002 UPLOAD_PROVIDER=smms SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX +PICGO_API_KEY=xxxx ########################################################################## #########################stream_optimizer 相关配置######################## STREAM_MIN_DELAY=0.016 diff --git a/Dockerfile b/Dockerfile index 9c365e7..b738aa8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ COPY ./app /app/app ENV API_KEYS='["your_api_key_1"]' ENV ALLOWED_TOKENS='["your_token_1"]' ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta -ENV TOOLS_CODE_EXECUTION_ENABLED=true +ENV TOOLS_CODE_EXECUTION_ENABLED=false ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]' # Expose port diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index 8905257..bc5bf15 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -23,7 +23,7 @@ async def get_key_manager(): async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)): return await key_manager.get_next_working_key() -model_service = ModelService(settings.MODEL_SEARCH) +model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE) @router.get("/models") @@ -42,6 +42,12 @@ async def list_models(_=Depends(security_service.verify_key), "outputTokenLimit": 8192, "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, "topP": 0.95, "topK": 64, 
"maxTemperature": 2}) + models_json["models"].append({"name": "models/gemini-2.0-flash-exp-image", "version": "2.0", + "displayName": "Gemini 2.0 Flash Image Experimental", + "description": "Gemini 2.0 Flash Image Experimental", "inputTokenLimit": 32767, + "outputTokenLimit": 8192, + "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, + "topP": 0.95, "topK": 64, "maxTemperature": 2}) return models_json diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index 68f021a..e70b4a6 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -17,7 +17,7 @@ logger = get_openai_logger() # 初始化服务 security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) -model_service = ModelService(settings.MODEL_SEARCH) +model_service = ModelService(settings.MODEL_SEARCH,settings.MODEL_IMAGE) embedding_service = EmbeddingService(settings.BASE_URL) image_create_service = ImageCreateService() diff --git a/app/core/config.py b/app/core/config.py index 7faa49c..f629bf7 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -7,6 +7,7 @@ class Settings(BaseSettings): ALLOWED_TOKENS: List[str] BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] + MODEL_IMAGE: List[str] = ["gemini-2.0-flash-exp"] TOOLS_CODE_EXECUTION_ENABLED: bool = False SHOW_SEARCH_LINK: bool = True SHOW_THINKING_PROCESS: bool = True @@ -16,6 +17,7 @@ class Settings(BaseSettings): CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002" UPLOAD_PROVIDER: str = "smms" SMMS_SECRET_TOKEN: str = "" + PICGO_API_KEY: str = "" TEST_MODEL: str = "gemini-1.5-flash" # 流式输出优化器配置 diff --git a/app/core/uploader.py b/app/core/uploader.py index f3f33de..e30211d 100644 --- a/app/core/uploader.py +++ b/app/core/uploader.py @@ -149,6 +149,116 @@ class QiniuUploader(ImageUploader): pass +class PicGoUploader(ImageUploader): + """Chevereto API 图片上传器""" + + def __init__(self, api_key: str, api_url: str = 
"https://www.picgo.net/api/1/upload"): + """ + 初始化 Chevereto 上传器 + + Args: + api_key: Chevereto API 密钥 + api_url: Chevereto API 上传地址 + """ + self.api_key = api_key + self.api_url = api_url + + def upload(self, file: bytes, filename: str) -> UploadResponse: + """ + 上传图片到 Chevereto 服务 + + Args: + file: 图片文件二进制数据 + filename: 文件名 + + Returns: + UploadResponse: 上传响应对象 + + Raises: + UploadError: 上传失败时抛出异常 + """ + try: + # 准备请求头 + headers = { + "X-API-Key": self.api_key + } + + # 准备文件数据 + files = { + "source": (filename, file) + } + + # 发送请求 + response = requests.post( + self.api_url, + headers=headers, + files=files + ) + + # 检查响应状态 + response.raise_for_status() + + # 解析响应 + result = response.json() + + # 验证上传是否成功 + if result.get("status_code") != 200: + error_message = "Upload failed" + if "error" in result: + error_message = result["error"].get("message", error_message) + raise UploadError( + message=error_message, + error_type=UploadErrorType.SERVER_ERROR, + status_code=result.get("status_code"), + details=result.get("error") + ) + + # 从响应中提取图片信息 + image_data = result.get("image", {}) + + # 构建图片元数据 + image_metadata = ImageMetadata( + width=image_data.get("width", 0), + height=image_data.get("height", 0), + filename=image_data.get("filename", filename), + size=image_data.get("size", 0), + url=image_data.get("url", ""), + delete_url=image_data.get("delete_url", None) + ) + + return UploadResponse( + success=True, + code="success", + message=result.get("success", {}).get("message", "Upload success"), + data=image_metadata + ) + + except requests.RequestException as e: + # 处理网络请求相关错误 + raise UploadError( + message=f"Upload request failed: {str(e)}", + error_type=UploadErrorType.NETWORK_ERROR, + original_error=e + ) + except (KeyError, ValueError, TypeError) as e: + # 处理响应解析错误 + raise UploadError( + message=f"Invalid response format: {str(e)}", + error_type=UploadErrorType.PARSE_ERROR, + original_error=e + ) + except UploadError: + # 重新抛出已经是 UploadError 类型的异常 + raise + 
except Exception as e: + # 处理其他未预期的错误 + raise UploadError( + message=f"Upload failed: {str(e)}", + error_type=UploadErrorType.UNKNOWN, + original_error=e + ) + + class ImageUploaderFactory: @staticmethod def create(provider: str, **credentials) -> ImageUploader: @@ -159,5 +269,7 @@ class ImageUploaderFactory: credentials["access_key"], credentials["secret_key"] ) + elif provider == "picgo": + api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload") + return PicGoUploader(credentials["api_key"], api_url) raise ValueError(f"Unknown provider: {provider}") - diff --git a/app/schemas/gemini_models.py b/app/schemas/gemini_models.py index 6869379..26515f1 100644 --- a/app/schemas/gemini_models.py +++ b/app/schemas/gemini_models.py @@ -33,8 +33,8 @@ class GeminiContent(BaseModel): class GeminiRequest(BaseModel): - contents: List[GeminiContent] + contents: List[GeminiContent] = [] tools: Optional[List[Dict[str, Any]]] = [] safetySettings: Optional[List[SafetySetting]] = None - generationConfig: Optional[GenerationConfig] = None + generationConfig: Optional[GenerationConfig] = {} systemInstruction: Optional[SystemInstruction] = None diff --git a/app/services/chat/api_client.py b/app/services/chat/api_client.py index 0a834d6..285a5dd 100644 --- a/app/services/chat/api_client.py +++ b/app/services/chat/api_client.py @@ -28,6 +28,8 @@ class GeminiApiClient(ApiClient): timeout = httpx.Timeout(self.timeout, read=self.timeout) if model.endswith("-search"): model = model[:-7] + if model.endswith("-image"): + model = model[:-6] async with httpx.AsyncClient(timeout=timeout) as client: url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" response = await client.post(url, json=payload) @@ -40,6 +42,8 @@ class GeminiApiClient(ApiClient): timeout = httpx.Timeout(self.timeout, read=self.timeout) if model.endswith("-search"): model = model[:-7] + if model.endswith("-image"): + model = model[:-6] async with httpx.AsyncClient(timeout=timeout) as client: url 
= f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" async with client.stream(method="POST", url=url, json=payload) as response: diff --git a/app/services/chat/response_handler.py b/app/services/chat/response_handler.py index 6f8d7bf..f9e4e7a 100644 --- a/app/services/chat/response_handler.py +++ b/app/services/chat/response_handler.py @@ -1,5 +1,6 @@ # app/services/chat/response_handler.py +import base64 import json import random import string @@ -8,6 +9,7 @@ from typing import Dict, Any, List, Optional import time import uuid from app.core.config import settings +from app.core.uploader import ImageUploaderFactory class ResponseHandler(ABC): @@ -135,67 +137,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, candidate = response["candidates"][0] content = candidate.get("content", {}) parts = content.get("parts", []) - # if "thinking" in model: - # if settings.SHOW_THINKING_PROCESS: - # if len(parts) == 1: - # if self.thinking_first: - # self.thinking_first = False - # self.thinking_status = True - # text = "> thinking\n\n" + parts[0].get("text") - # else: - # text = parts[0].get("text") - - # if len(parts) == 2: - # self.thinking_status = False - # if self.thinking_first: - # self.thinking_first = False - # text = ( - # "> thinking\n\n" - # + parts[0].get("text") - # + "\n\n---\n> output\n\n" - # + parts[1].get("text") - # ) - # else: - # text = ( - # parts[0].get("text") - # + "\n\n---\n> output\n\n" - # + parts[1].get("text") - # ) - # else: - # if len(parts) == 1: - # if self.thinking_first: - # self.thinking_first = False - # self.thinking_status = True - # text = "" - # elif self.thinking_status: - # text = "" - # else: - # text = parts[0].get("text") - - # if len(parts) == 2: - # self.thinking_status = False - # if self.thinking_first: - # self.thinking_first = False - # text = parts[1].get("text") - # else: - # text = parts[1].get("text") - # else: - # if "text" in parts[0]: - # text = 
parts[0].get("text") - # elif "executableCode" in parts[0]: - # text = _format_code_block(parts[0]["executableCode"]) - # elif "codeExecution" in parts[0]: - # text = _format_code_block(parts[0]["codeExecution"]) - # elif "executableCodeResult" in parts[0]: - # text = _format_execution_result( - # parts[0]["executableCodeResult"] - # ) - # elif "codeExecutionResult" in parts[0]: - # text = _format_execution_result( - # parts[0]["codeExecutionResult"] - # ) - # else: - # text = "" + if not parts: + return "", [] if "text" in parts[0]: text = parts[0].get("text") elif "executableCode" in parts[0]: @@ -210,6 +153,8 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, text = _format_execution_result( parts[0]["codeExecutionResult"] ) + elif "inlineData" in parts[0]: + text = _extract_image_data(parts[0]) else: text = "" text = _add_search_link_text(model, candidate, text) @@ -235,14 +180,38 @@ def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, text = candidate["content"]["parts"][0]["text"] else: text = "" - for part in candidate["content"]["parts"]: - text += part.get("text", "") + if "parts" in candidate["content"]: + for part in candidate["content"]["parts"]: + if "text" in part: + text += part["text"] + elif "inlineData" in part: + text += _extract_image_data(part) + + text = _add_search_link_text(model, candidate, text) tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format) else: text = "暂无返回" return text, tool_calls +def _extract_image_data(part: dict) -> str: + image_uploader = None + if settings.UPLOAD_PROVIDER == "smms": + image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.SMMS_SECRET_TOKEN) + elif settings.UPLOAD_PROVIDER == "picgo": + image_uploader = ImageUploaderFactory.create(provider=settings.UPLOAD_PROVIDER,api_key=settings.PICGO_API_KEY) + current_date = time.strftime("%Y/%m/%d") + filename = 
f"{current_date}/{uuid.uuid4().hex[:8]}.png" + base64_data = part["inlineData"]["data"] + #将base64_data转成bytes数组 + bytes_data = base64.b64decode(base64_data) + upload_response = image_uploader.upload(bytes_data,filename) + if upload_response.success: + text = f"\n![image]({upload_response.data.url})\n" + else: + text = "" + return text + def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]: """提取工具调用信息""" if not parts or not isinstance(parts, list): diff --git a/app/services/gemini_chat_service.py b/app/services/gemini_chat_service.py index 51bac9f..7936249 100644 --- a/app/services/gemini_chat_service.py +++ b/app/services/gemini_chat_service.py @@ -62,14 +62,19 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]: def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: """构建请求payload""" - payload = request.model_dump() - return { - "contents": payload.get("contents", []), - "tools": _build_tools(model, payload), + request_dict = request.model_dump() + payload = { + "contents": request_dict.get("contents", []), + "tools": _build_tools(model, request_dict), "safetySettings": _get_safety_settings(model), - "generationConfig": payload.get("generationConfig", {}), - "systemInstruction": payload.get("systemInstruction", []) + "generationConfig": request_dict.get("generationConfig", {}), + "systemInstruction": request_dict.get("systemInstruction", "") } + + if model.endswith("-image"): + payload.pop("systemInstruction") + payload["generationConfig"]["responseModalities"] = ["Text","Image"] + return payload class GeminiChatService: diff --git a/app/services/model_service.py b/app/services/model_service.py index b0a2fbd..9239dfc 100644 --- a/app/services/model_service.py +++ b/app/services/model_service.py @@ -7,8 +7,9 @@ from app.core.config import settings logger = get_model_logger() class ModelService: - def __init__(self, model_search: list): + def __init__(self, model_search: list, model_image: 
list): self.model_search = model_search + self.model_image = model_image self.base_url = "https://generativelanguage.googleapis.com/v1beta" def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: @@ -57,6 +58,10 @@ class ModelService: search_model = openai_model.copy() search_model["id"] = f"{model_id}-search" openai_format["data"].append(search_model) + if model_id in self.model_image: + image_model = openai_model.copy() + image_model["id"] = f"{model_id}-image" + openai_format["data"].append(image_model) if settings.CREATE_IMAGE_MODEL: image_model = openai_model.copy() diff --git a/app/services/openai_chat_service.py b/app/services/openai_chat_service.py index 3c531e5..10b07fb 100644 --- a/app/services/openai_chat_service.py +++ b/app/services/openai_chat_service.py @@ -35,7 +35,7 @@ def _build_tools( if ( settings.TOOLS_CODE_EXECUTION_ENABLED - and not (model.endswith("-search") or "-thinking" in model) + and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image")) and not _has_image_parts(messages) ): tools.append({"code_execution": {}}) @@ -110,12 +110,15 @@ def _build_payload( "tools": _build_tools(request, messages), "safetySettings": _get_safety_settings(request.model), } - + if request.model.endswith("-image"): + payload["generationConfig"]["responseModalities"] = ["Text","Image"] + if ( instruction and isinstance(instruction, dict) and instruction.get("role") == "system" and instruction.get("parts") + and not request.model.endswith("-image") ): payload["systemInstruction"] = instruction