diff --git a/.vscode/launch.json b/.vscode/launch.json index 0eca816..4a979d6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,8 +10,12 @@ "request": "launch", "module": "uvicorn", "args": [ - "main:app", - "--reload" + "app.main:app", + "--reload", + "--host", + "127.0.0.1", + "--port", + "8000" ], "jinja": true } diff --git a/app/api/dependencies.py b/app/api/dependencies.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/routes.py b/app/api/routes.py new file mode 100644 index 0000000..cbbdf80 --- /dev/null +++ b/app/api/routes.py @@ -0,0 +1,98 @@ +from fastapi import APIRouter, Depends, Header +from typing import Optional +import logging +from fastapi.responses import StreamingResponse + +from app.core.security import SecurityService +from app.services.key_manager import KeyManager +from app.services.model_service import ModelService +from app.services.chat_service import ChatService +from app.services.embedding_service import EmbeddingService +from app.schemas.request_model import ChatRequest, EmbeddingRequest +from app.core.config import settings + +router = APIRouter() +logger = logging.getLogger(__name__) + +# 初始化服务 +security_service = SecurityService(settings.ALLOWED_TOKENS) +key_manager = KeyManager(settings.API_KEYS) +model_service = ModelService(settings.MODEL_SEARCH) +chat_service = ChatService(settings.BASE_URL, key_manager) +embedding_service = EmbeddingService(settings.BASE_URL) + + +@router.get("/v1/models") +@router.get("/hf/v1/models") +async def list_models( + authorization: str = Header(None), + token: str = Depends(security_service.verify_authorization), +): + logger.info("Handling models list request") + api_key = await key_manager.get_next_working_key() + logger.info(f"Using API key: {api_key}") + return model_service.get_gemini_models(api_key) + + +@router.post("/v1/chat/completions") +@router.post("/hf/v1/chat/completions") +async def chat_completion( + request: ChatRequest, + authorization: str = Header(None), + token: str = Depends(security_service.verify_authorization), +): + logger.info(f"Handling chat completion request for model: {request.model}") + api_key = await key_manager.get_next_working_key() + logger.info(f"Using API key: {api_key}") + retries = 0 + MAX_RETRIES = 3 + + while retries < MAX_RETRIES: + try: + response = await chat_service.create_chat_completion( + messages=request.messages, + model=request.model, + temperature=request.temperature, + stream=request.stream, + api_key=api_key, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + + # 处理流式响应 + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + + except Exception as e: + logger.warning( + f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}" + ) + api_key = await key_manager.handle_api_failure(api_key) + logger.info(f"Switched to new API key: {api_key}") + retries += 1 + if retries >= MAX_RETRIES: + logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error") + raise + + +@router.post("/v1/embeddings") +@router.post("/hf/v1/embeddings") +async def embedding( + request: EmbeddingRequest, + authorization: str = Header(None), + token: str = Depends(security_service.verify_authorization), +): + logger.info(f"Handling embedding request for model: {request.model}") + api_key = await key_manager.get_next_working_key() + logger.info(f"Using API key: {api_key}") + try: + response = await embedding_service.create_embedding( + input_text=request.input, model=request.model, api_key=api_key + ) + logger.info("Embedding request successful") + return response + except Exception as e: + logger.error(f"Embedding request failed: {str(e)}") + raise diff --git a/app/config.py b/app/config.py deleted file mode 100644 index 758f55b..0000000 --- a/app/config.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic_settings import BaseSettings -import os -from typing import List - -class Settings(BaseSettings): - API_KEYS: List[str] - ALLOWED_TOKENS: List[str] - BASE_URL: str - MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] - - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - case_sensitive = True - # 同时从环境变量和.env文件获取配置 - env_nested_delimiter = "__" - extra = "ignore" - -# 优先从环境变量获取,如果没有则从.env文件获取 -settings = Settings(_env_file=os.getenv("ENV_FILE", ".env")) \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..38c82da --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,16 @@ +from pydantic_settings import BaseSettings +from typing import List + + +class Settings(BaseSettings): + API_KEYS: List[str] + ALLOWED_TOKENS: List[str] + BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" + MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] + TOOLS_CODE_EXECUTION_ENABLED: bool = False + + class Config: + env_file = ".env" + + +settings = Settings() diff --git a/app/core/security.py b/app/core/security.py new file mode 100644 index 0000000..00e13be --- /dev/null +++ b/app/core/security.py @@ -0,0 +1,30 @@ +from fastapi import HTTPException, Header +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class SecurityService: + def __init__(self, allowed_tokens: list): + self.allowed_tokens = allowed_tokens + + async def verify_authorization( + self, authorization: Optional[str] = Header(None) + ) -> str: + if not authorization: + logger.error("Missing Authorization header") + raise HTTPException(status_code=401, detail="Missing Authorization header") + + if not authorization.startswith("Bearer "): + logger.error("Invalid Authorization header format") + raise HTTPException( + status_code=401, detail="Invalid Authorization header format" + ) + + token = authorization.replace("Bearer ", "") + if token not in self.allowed_tokens: + logger.error("Invalid token") + raise HTTPException(status_code=401, detail="Invalid token") + + return token diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..fab4a88 --- /dev/null +++ b/app/main.py @@ -0,0 +1,38 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +import logging + +from app.api.routes import router +import uvicorn + +# 配置日志 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +app = FastAPI() + +# 允许跨域 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 包含所有路由 +app.include_router(router) + + +@app.get("/health") +@app.get("/") +async def health_check(): + logger.info("Health check endpoint called") + return {"status": "healthy"} + + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/app/schemas/request_model.py b/app/schemas/request_model.py new file mode 100644 index 0000000..1b6807c --- /dev/null +++ b/app/schemas/request_model.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel +from typing import List, Optional, Union + + +class ChatRequest(BaseModel): + messages: List[dict] + model: str = "gemini-1.5-flash-002" + temperature: Optional[float] = 0.7 + stream: Optional[bool] = False + tools: Optional[List[dict]] = [] + tool_choice: Optional[str] = "auto" + + +class EmbeddingRequest(BaseModel): + input: Union[str, List[str]] + model: str = "text-embedding-004" + encoding_format: Optional[str] = "float" diff --git a/app/schemas/response_model.py b/app/schemas/response_model.py new file mode 100644 index 0000000..e69de29 diff --git a/app/services/chat_service.py b/app/services/chat_service.py new file mode 100644 index 0000000..6668ada --- /dev/null +++ b/app/services/chat_service.py @@ -0,0 +1,299 @@ +import httpx +import json +import time +import uuid +import logging +from typing import Dict, Any, Optional, AsyncGenerator, Union +import openai +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class ChatService: + def __init__(self, base_url: str, key_manager=None): + self.base_url = base_url + self.key_manager = key_manager + + def convert_messages_to_gemini_format(self, messages: list) -> list: + """Convert OpenAI message format to Gemini format""" + converted_messages = [] + for msg in messages: + role = "user" if msg["role"] == "user" else "model" + parts = [] + + # 处理文本内容 + if isinstance(msg["content"], str): + parts.append({"text": msg["content"]}) + # 处理包含图片的消息 + elif isinstance(msg["content"], list): + for content in msg["content"]: + if isinstance(content, str): + parts.append({"text": content}) + elif isinstance(content, dict) and content["type"] == "text": + parts.append({"text": content["text"]}) + elif isinstance(content, dict) and content["type"] == "image_url": + # 处理图片URL + image_url = content["image_url"]["url"] + if image_url.startswith("data:image"): + # 处理base64图片 + parts.append( + { + "inline_data": { + "mime_type": "image/jpeg", + "data": image_url.split(",")[1], + } + } + ) + else: + # 处理普通URL图片 + parts.append( + { + "inline_data": { + "mime_type": "image/jpeg", + "data": image_url, + } + } + ) + + converted_messages.append({"role": role, "parts": parts}) + + return converted_messages + + def convert_gemini_response_to_openai( + self, response: Dict[str, Any], model: str, stream: bool = False + ) -> Optional[Dict[str, Any]]: + """Convert Gemini response to OpenAI format""" + if stream: + if not response.get("candidates"): + return None + + try: + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + + if not parts: + return None + + if "text" in parts[0]: + text = parts[0].get("text") + elif "executableCode" in parts[0]: + text = self.format_code_block(parts[0]["executableCode"]) + elif "executableCodeResult" in parts[0]: + text = self.format_execution_result(parts[0]["executableCodeResult"]) + else: + text = "" + + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": text}, + "finish_reason": None, + } + ], + } + except Exception as e: + logger.error(f"Error converting Gemini response: {str(e)}") + logger.debug(f"Raw response: {response}") + return None + else: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response["candidates"][0]["content"]["parts"][0][ + "text" + ], + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + async def create_chat_completion( + self, + messages: list, + model: str, + temperature: float, + stream: bool, + api_key: str, + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """Create chat completion using either Gemini or OpenAI API""" + + if tools is None: + tools = [] + if settings.TOOLS_CODE_EXECUTION_ENABLED: + tools.append({"code_execution": {}}) + if model.endswith("-search"): + tools.append({"googleSearch": {}}) + return await self._gemini_chat_completion( + messages, model, temperature, stream, api_key, tools + ) + # else: + # return await self._openai_chat_completion( + # messages, model, temperature, stream, api_key, tools + # ) + + async def _gemini_chat_completion( + self, + messages: list, + model: str, + temperature: float, + stream: bool, + api_key: str, + tools: Optional[list] = None, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """Handle Gemini API chat completion""" + if model.endswith("-search"): + gemini_model = model[:-7] # Remove -search suffix + else: + gemini_model = model + gemini_messages = self.convert_messages_to_gemini_format(messages) + + payload = { + "contents": gemini_messages, + "generationConfig": {"temperature": temperature}, + "tools": tools, + } + + if stream: + + async def generate(): + retries = 0 + MAX_RETRIES = 3 + current_api_key = api_key + + while retries < MAX_RETRIES: + try: + async with httpx.AsyncClient() as client: + stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}" + async with client.stream( + "POST", stream_url, json=payload + ) as response: + if response.status_code != 200: + if retries < MAX_RETRIES - 1: + logger.warning( + f"API error: {response.status_code}, attempting retry {retries + 1}" + ) + current_api_key = ( + await self.key_manager.handle_api_failure( + current_api_key + ) + ) + logger.info( + f"Switched to new API key: {current_api_key}" + ) + retries += 1 + continue + else: + logger.error( + f"Max retries reached. Final error: {response.status_code}" + ) + yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n" + return + + async for line in response.aiter_lines(): + if line.startswith("data: "): + try: + chunk = json.loads(line[6:]) + openai_chunk = ( + self.convert_gemini_response_to_openai( + chunk, model, stream=True + ) + ) + if openai_chunk: + yield f"data: {json.dumps(openai_chunk)}\n\n" + except json.JSONDecodeError: + continue + yield "data: [DONE]\n\n" + return # 成功完成,退出重试循环 + + except Exception as e: + if retries < MAX_RETRIES - 1: + logger.warning( + f"Stream error: {str(e)}, attempting retry {retries + 1}" + ) + current_api_key = await self.key_manager.handle_api_failure( + current_api_key + ) + retries += 1 + continue + else: + logger.error(f"Max retries reached. Final error: {str(e)}") + yield f"data: {json.dumps({'error': str(e)})}\n\n" + return + + return generate() + else: + async with httpx.AsyncClient() as client: + url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}" + response = await client.post(url, json=payload) + gemini_response = response.json() + return self.convert_gemini_response_to_openai(gemini_response, model) + + async def _openai_chat_completion( + self, + messages: list, + model: str, + temperature: float, + stream: bool, + api_key: str, + tools: Optional[list] = None, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """Handle OpenAI API chat completion""" + client = openai.OpenAI(api_key=api_key, base_url=self.base_url) + if tools: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + stream=stream, + tools=tools, + ) + else: + response = client.chat.completions.create( + model=model, messages=messages, temperature=temperature, stream=stream + ) + + if stream: + + async def generate(): + for chunk in response: + yield f"data: {chunk.model_dump_json()}\n\n" + + return generate() + + return response + + def format_code_block(self, code_data: dict) -> str: + """格式化代码块输出""" + language = code_data.get("language", "").lower() + code = code_data.get("code", "").strip() + + return f"""\n```{language}\n{code}\n```\n""" + + + def format_execution_result(result_data: dict) -> str: + """格式化执行结果输出""" + outcome = result_data.get("outcome", "") + output = result_data.get("output", "").strip() + return f"""\n【执行结果】\n{output}\n""" diff --git a/app/services/embedding_service.py b/app/services/embedding_service.py new file mode 100644 index 0000000..bb8e3c1 --- /dev/null +++ b/app/services/embedding_service.py @@ -0,0 +1,22 @@ +import logging +import openai +from typing import Union, List, Dict, Any + +logger = logging.getLogger(__name__) + + +class EmbeddingService: + def __init__(self, base_url: str): + self.base_url = base_url + + async def create_embedding( + self, input_text: Union[str, List[str]], model: str, api_key: str + ) -> Dict[str, Any]: + """Create embeddings using OpenAI API""" + try: + client = openai.OpenAI(api_key=api_key, base_url=self.base_url) + response = client.embeddings.create(input=input_text, model=model) + return response + except Exception as e: + logger.error(f"Error creating embedding: {str(e)}") + raise diff --git a/app/services/key_manager.py b/app/services/key_manager.py new file mode 100644 index 0000000..c9cb5ca --- /dev/null +++ b/app/services/key_manager.py @@ -0,0 +1,57 @@ +import asyncio +from itertools import cycle +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + + +class KeyManager: + def __init__(self, api_keys: list): + self.api_keys = api_keys + self.key_cycle = cycle(api_keys) + self.key_cycle_lock = asyncio.Lock() + self.failure_count_lock = asyncio.Lock() + self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys} + self.MAX_FAILURES = 10 + + async def get_next_key(self) -> str: + """获取下一个API key""" + async with self.key_cycle_lock: + return next(self.key_cycle) + + async def is_key_valid(self, key: str) -> bool: + """检查key是否有效""" + async with self.failure_count_lock: + return self.key_failure_counts[key] < self.MAX_FAILURES + + async def reset_failure_counts(self): + """重置所有key的失败计数""" + async with self.failure_count_lock: + for key in self.key_failure_counts: + self.key_failure_counts[key] = 0 + + async def get_next_working_key(self) -> str: + """获取下一个可用的API key""" + initial_key = await self.get_next_key() + current_key = initial_key + + while True: + if await self.is_key_valid(current_key): + return current_key + + current_key = await self.get_next_key() + if current_key == initial_key: + await self.reset_failure_counts() + return current_key + + async def handle_api_failure(self, api_key: str) -> str: + """处理API调用失败""" + async with self.failure_count_lock: + self.key_failure_counts[api_key] += 1 + if self.key_failure_counts[api_key] >= self.MAX_FAILURES: + logger.warning( + f"API key {api_key} has failed {self.MAX_FAILURES} times" + ) + + return await self.get_next_working_key() diff --git a/app/services/model_service.py b/app/services/model_service.py new file mode 100644 index 0000000..49fcdd5 --- /dev/null +++ b/app/services/model_service.py @@ -0,0 +1,55 @@ +import requests +from datetime import datetime, timezone +from typing import Optional, Dict, Any +import logging + +logger = logging.getLogger(__name__) + + +class ModelService: + def __init__(self, model_search: list): + self.model_search = model_search + + def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: + base_url = "https://generativelanguage.googleapis.com/v1beta" + url = f"{base_url}/models?key={api_key}" + + try: + response = requests.get(url) + if response.status_code == 200: + gemini_models = response.json() + return self.convert_to_openai_models_format(gemini_models) + else: + logger.error(f"Error: {response.status_code}") + logger.error(response.text) + return None + + except requests.RequestException as e: + logger.error(f"Request failed: {e}") + return None + + def convert_to_openai_models_format( + self, gemini_models: Dict[str, Any] + ) -> Dict[str, Any]: + openai_format = {"object": "list", "data": []} + + for model in gemini_models.get("models", []): + model_id = model["name"].split("/")[-1] + openai_model = { + "id": model_id, + "object": "model", + "created": int(datetime.now(timezone.utc).timestamp()), + "owned_by": "google", + "permission": [], + "root": model["name"], + "parent": None, + "success": True, + } + openai_format["data"].append(openai_model) + + if model_id in self.model_search: + search_model = openai_model.copy() + search_model["id"] = f"{model_id}-search" + openai_format["data"].append(search_model) + + return openai_format diff --git a/app/utils/helpers.py b/app/utils/helpers.py new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py deleted file mode 100644 index d9be0c4..0000000 --- a/main.py +++ /dev/null @@ -1,416 +0,0 @@ -from fastapi import FastAPI, HTTPException, Header -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -import openai -from typing import List, Optional, Union -import logging -from itertools import cycle -import asyncio - -import uvicorn - -from app import config -import requests -from datetime import datetime, timezone -import json -import httpx -import uuid -import time - -# 配置日志 -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - -app = FastAPI() - -# 允许跨域 -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# API密钥配置 -API_KEYS = config.settings.API_KEYS - -# 创建一个循环迭代器 -key_cycle = cycle(API_KEYS) - -# 创建两个独立的锁 -key_cycle_lock = asyncio.Lock() -failure_count_lock = asyncio.Lock() - -# 添加key失败计数记录 -key_failure_counts = {key: 0 for key in API_KEYS} -MAX_FAILURES = 10 # 最大失败次数阈值 -MAX_RETRIES = 3 # 最大重试次数 - - -async def get_next_key(): - """仅获取下一个key,不检查失败次数""" - async with key_cycle_lock: - return next(key_cycle) - - -async def is_key_valid(key): - """检查key是否有效""" - async with failure_count_lock: - return key_failure_counts[key] < MAX_FAILURES - - -async def reset_failure_counts(): - """重置所有key的失败计数""" - async with failure_count_lock: - for key in key_failure_counts: - key_failure_counts[key] = 0 - - -async def get_next_working_key(): - """获取下一个可用的API key""" - initial_key = await get_next_key() - current_key = initial_key - - while True: - if await is_key_valid(current_key): - return current_key - - current_key = await get_next_key() - if current_key == initial_key: # 已经循环了一圈 - await reset_failure_counts() - return current_key - - -async def handle_api_failure(api_key): - """处理API调用失败""" - async with failure_count_lock: - key_failure_counts[api_key] += 1 - if key_failure_counts[api_key] >= MAX_FAILURES: - logger.warning( - f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key" - ) - - # 在锁外获取新的key - return await get_next_working_key() - - -class ChatRequest(BaseModel): - messages: List[dict] - model: str = "gemini-1.5-flash-002" - temperature: Optional[float] = 0.7 - stream: Optional[bool] = False - tools: Optional[List[dict]] = [] - tool_choice: Optional[str] = "auto" - - -class EmbeddingRequest(BaseModel): - input: Union[str, List[str]] - model: str = "text-embedding-004" - encoding_format: Optional[str] = "float" - - -async def verify_authorization(authorization: str = Header(None)): - if not authorization: - logger.error("Missing Authorization header") - raise HTTPException(status_code=401, detail="Missing Authorization header") - if not authorization.startswith("Bearer "): - logger.error("Invalid Authorization header format") - raise HTTPException( - status_code=401, detail="Invalid Authorization header format" - ) - token = authorization.replace("Bearer ", "") - if token not in config.settings.ALLOWED_TOKENS: - logger.error("Invalid token") - raise HTTPException(status_code=401, detail="Invalid token") - return token - - -def get_gemini_models(api_key): - base_url = "https://generativelanguage.googleapis.com/v1beta" - url = f"{base_url}/models?key={api_key}" - - try: - response = requests.get(url) - if response.status_code == 200: - gemini_models = response.json() - return convert_to_openai_models_format(gemini_models) - else: - print(f"Error: {response.status_code}") - print(response.text) - return None - - except requests.RequestException as e: - print(f"Request failed: {e}") - return None - - -def convert_to_openai_models_format(gemini_models): - openai_format = {"object": "list", "data": []} - - # 添加常规模型 - for model in gemini_models.get("models", []): - model_id = model["name"].split("/")[-1] - openai_model = { - "id": model_id, - "object": "model", - "created": int(datetime.now(timezone.utc).timestamp()), # 使用当前时间戳 - "owned_by": "google", # 假设所有Gemini模型都由Google拥有 - "permission": [], # Gemini API可能没有直接对应的权限信息 - "root": model["name"], - "parent": None, # Gemini API可能没有直接对应的父模型信息 - } - openai_format["data"].append(openai_model) - - # 如果模型在 MODEL_SEARCH 中,添加带 search 后缀的版本 - if model_id in config.settings.MODEL_SEARCH: - search_model = openai_model.copy() - search_model["id"] = f"{model_id}-search" - openai_format["data"].append(search_model) - - return openai_format - - -def convert_messages_to_gemini_format(messages): - """Convert OpenAI message format to Gemini format""" - gemini_messages = [] - for message in messages: - gemini_message = { - "role": "user" if message["role"] == "user" else "model", - "parts": [{"text": message["content"]}], - } - gemini_messages.append(gemini_message) - return gemini_messages - - -def convert_gemini_response_to_openai(response, model, stream=False): - """Convert Gemini response to OpenAI format""" - if stream: - # 处理流式响应 - chunk = response - if not chunk["candidates"]: - return None - - return { - "id": "chatcmpl-" + str(uuid.uuid4()), - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": chunk["candidates"][0]["content"]["parts"][0]["text"] - }, - "finish_reason": None, - } - ], - } - else: - # 处理普通响应 - return { - "id": "chatcmpl-" + str(uuid.uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": response["candidates"][0]["content"]["parts"][0][ - "text" - ], - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - } - - -@app.get("/v1/models") -@app.get("/hf/v1/models") -async def list_models(authorization: str = Header(None)): - await verify_authorization(authorization) - api_key = await get_next_working_key() - logger.info(f"Using API key: {api_key}") - try: - response = get_gemini_models(api_key) - logger.info("Successfully retrieved models list") - return response - except Exception as e: - logger.error(f"Error listing models: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/v1/chat/completions") -@app.post("/hf/v1/chat/completions") -async def chat_completion(request: ChatRequest, authorization: str = Header(None)): - await verify_authorization(authorization) - api_key = await get_next_working_key() - logger.info(f"Chat completion request - Model: {request.model}") - retries = 0 - - while retries < MAX_RETRIES: - try: - logger.info(f"Attempt {retries + 1} with API key: {api_key}") - - # 修改判断条件,检查模型名是否以 -search 结尾 - if request.model.endswith("-search"): - # 去掉 -search 后缀 - gemini_model = request.model[:-7] - - # Gemini API调用部分 - gemini_messages = convert_messages_to_gemini_format(request.messages) - # 调用Gemini API - payload = { - "contents": gemini_messages, - "generationConfig": { - "temperature": request.temperature, - }, - "tools": [{"googleSearch": {}}], - } - - if request.stream: - logger.info("Streaming response enabled") - - async def generate(): - nonlocal api_key, retries - while retries < MAX_RETRIES: - try: - async with httpx.AsyncClient() as client: - stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}" - async with client.stream( - "POST", stream_url, json=payload - ) as response: - if response.status_code == 429: - logger.warning( - f"Rate limit reached for key: {api_key}" - ) - api_key = await handle_api_failure(api_key) - logger.info( - f"Retrying with new API key: {api_key}" - ) - retries += 1 - if retries >= MAX_RETRIES: - yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n" - break - continue - - if response.status_code != 200: - logger.error( - f"Error in streaming response: {response.status_code}" - ) - yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n" - break - - async for line in response.aiter_lines(): - if line.startswith("data: "): - try: - chunk = json.loads(line[6:]) - openai_chunk = convert_gemini_response_to_openai( - chunk, - request.model, - stream=True, - ) - if openai_chunk: - yield f"data: {json.dumps(openai_chunk)}\n\n" - except json.JSONDecodeError: - continue - yield "data: [DONE]\n\n" - return - except Exception as e: - logger.error(f"Stream error: {str(e)}") - api_key = await handle_api_failure(api_key) - retries += 1 - if retries >= MAX_RETRIES: - yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n" - break - continue - - return StreamingResponse( - content=generate(), media_type="text/event-stream" - ) - else: - # 非流式响应 - async with httpx.AsyncClient() as client: - non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}" - response = await client.post(non_stream_url, json=payload) - gemini_response = response.json() - logger.info("Chat completion successful") - return convert_gemini_response_to_openai( - gemini_response, request.model - ) - - # OpenAI API调用部分 - client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL) - response = client.chat.completions.create( - model=request.model, - messages=request.messages, - temperature=request.temperature, - stream=request.stream if hasattr(request, "stream") else False, - ) - - if hasattr(request, "stream") and request.stream: - logger.info("Streaming response enabled") - - async def generate(): - for chunk in response: - yield f"data: {chunk.model_dump_json()}\n\n" - - logger.info("Chat completion successful") - return StreamingResponse( - content=generate(), media_type="text/event-stream" - ) - - logger.info("Chat completion successful") - return response - - except Exception as e: - logger.error(f"Error in chat completion: {str(e)}") - api_key = await handle_api_failure(api_key) - retries += 1 - - if retries >= MAX_RETRIES: - logger.error("Max retries reached, giving up") - raise HTTPException( - status_code=500, - detail="Max retries reached with all available API keys", - ) - - logger.info(f"Retrying with new API key: {api_key}") - continue - - raise HTTPException(status_code=500, detail="Unexpected error in chat completion") - - -@app.post("/v1/embeddings") -@app.post("/hf/v1/embeddings") -async def embedding(request: EmbeddingRequest, authorization: str = Header(None)): - await verify_authorization(authorization) - api_key = await get_next_working_key() - logger.info(f"Using API key: {api_key}") - - try: - client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL) - response = client.embeddings.create(input=request.input, model=request.model) - logger.info("Embedding successful") - return response - except Exception as e: - logger.error(f"Error in embedding: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/health") -@app.get("/") -async def health_check(): - logger.info("Health check endpoint called") - return {"status": "healthy"} - - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000)