添加API密钥管理、模型服务和安全服务,并优化FastAPI应用程序配置

This commit is contained in:
yinpeng
2024-12-15 11:08:35 +08:00
parent 03c201c849
commit c56bea0b25
15 changed files with 638 additions and 438 deletions

8
.vscode/launch.json vendored
View File

@@ -10,8 +10,12 @@
"request": "launch",
"module": "uvicorn",
"args": [
"main:app",
"--reload"
"app.main:app",
"--reload",
"--host",
"127.0.0.1",
"--port",
"8000"
],
"jinja": true
}

0
app/api/dependencies.py Normal file
View File

98
app/api/routes.py Normal file
View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter, Depends, Header
from typing import Optional
import logging
from fastapi.responses import StreamingResponse
from app.core.security import SecurityService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat_service import ChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.request_model import ChatRequest, EmbeddingRequest
from app.core.config import settings
router = APIRouter()
logger = logging.getLogger(__name__)
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
chat_service = ChatService(settings.BASE_URL, key_manager)
embedding_service = EmbeddingService(settings.BASE_URL)
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
):
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return model_service.get_gemini_models(api_key)
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
async def chat_completion(
request: ChatRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
):
logger.info(f"Handling chat completion request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
MAX_RETRIES = 3
while retries < MAX_RETRIES:
try:
response = await chat_service.create_chat_completion(
messages=request.messages,
model=request.model,
temperature=request.temperature,
stream=request.stream,
api_key=api_key,
tools=request.tools,
tool_choice=request.tool_choice,
)
# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= MAX_RETRIES:
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
raise
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
):
logger.info(f"Handling embedding request for model: {request.model}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
response = await embedding_service.create_embedding(
input_text=request.input, model=request.model, api_key=api_key
)
logger.info("Embedding request successful")
return response
except Exception as e:
logger.error(f"Embedding request failed: {str(e)}")
raise

View File

@@ -1,20 +0,0 @@
from pydantic_settings import BaseSettings
import os
from typing import List
class Settings(BaseSettings):
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = True
# 同时从环境变量和.env文件获取配置
env_nested_delimiter = "__"
extra = "ignore"
# 优先从环境变量获取,如果没有则从.env文件获取
settings = Settings(_env_file=os.getenv("ENV_FILE", ".env"))

16
app/core/config.py Normal file
View File

@@ -0,0 +1,16 @@
from pydantic_settings import BaseSettings
from typing import List
class Settings(BaseSettings):
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
class Config:
env_file = ".env"
settings = Settings()

30
app/core/security.py Normal file
View File

@@ -0,0 +1,30 @@
from fastapi import HTTPException, Header
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class SecurityService:
def __init__(self, allowed_tokens: list):
self.allowed_tokens = allowed_tokens
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
raise HTTPException(status_code=401, detail="Missing Authorization header")
if not authorization.startswith("Bearer "):
logger.error("Invalid Authorization header format")
raise HTTPException(
status_code=401, detail="Invalid Authorization header format"
)
token = authorization.replace("Bearer ", "")
if token not in self.allowed_tokens:
logger.error("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
return token

38
app/main.py Normal file
View File

@@ -0,0 +1,38 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import logging
from app.api.routes import router
import uvicorn
# 配置日志
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
# 允许跨域
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 包含所有路由
app.include_router(router)
@app.get("/health")
@app.get("/")
async def health_check():
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@@ -0,0 +1,17 @@
from pydantic import BaseModel
from typing import List, Optional, Union
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
tool_choice: Optional[str] = "auto"
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-004"
encoding_format: Optional[str] = "float"

View File

View File

@@ -0,0 +1,299 @@
import httpx
import json
import time
import uuid
import logging
from typing import Dict, Any, Optional, AsyncGenerator, Union
import openai
from app.core.config import settings
logger = logging.getLogger(__name__)
class ChatService:
def __init__(self, base_url: str, key_manager=None):
self.base_url = base_url
self.key_manager = key_manager
def convert_messages_to_gemini_format(self, messages: list) -> list:
"""Convert OpenAI message format to Gemini format"""
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
# 处理文本内容
if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
# 处理包含图片的消息
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str):
parts.append({"text": content})
elif isinstance(content, dict) and content["type"] == "text":
parts.append({"text": content["text"]})
elif isinstance(content, dict) and content["type"] == "image_url":
# 处理图片URL
image_url = content["image_url"]["url"]
if image_url.startswith("data:image"):
# 处理base64图片
parts.append(
{
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1],
}
}
)
else:
# 处理普通URL图片
parts.append(
{
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url,
}
}
)
converted_messages.append({"role": role, "parts": parts})
return converted_messages
def convert_gemini_response_to_openai(
self, response: Dict[str, Any], model: str, stream: bool = False
) -> Optional[Dict[str, Any]]:
"""Convert Gemini response to OpenAI format"""
if stream:
if not response.get("candidates"):
return None
try:
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if not parts:
return None
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = self.format_code_block(parts[0]["executableCode"])
elif "executableCodeResult" in parts[0]:
text = self.format_execution_result(parts[0]["executableCodeResult"])
else:
text = ""
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": text},
"finish_reason": None,
}
],
}
except Exception as e:
logger.error(f"Error converting Gemini response: {str(e)}")
logger.debug(f"Raw response: {response}")
return None
else:
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response["candidates"][0]["content"]["parts"][0][
"text"
],
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
async def create_chat_completion(
self,
messages: list,
model: str,
temperature: float,
stream: bool,
api_key: str,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Create chat completion using either Gemini or OpenAI API"""
if tools is None:
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED:
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return await self._gemini_chat_completion(
messages, model, temperature, stream, api_key, tools
)
# else:
# return await self._openai_chat_completion(
# messages, model, temperature, stream, api_key, tools
# )
async def _gemini_chat_completion(
self,
messages: list,
model: str,
temperature: float,
stream: bool,
api_key: str,
tools: Optional[list] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Handle Gemini API chat completion"""
if model.endswith("-search"):
gemini_model = model[:-7] # Remove -search suffix
else:
gemini_model = model
gemini_messages = self.convert_messages_to_gemini_format(messages)
payload = {
"contents": gemini_messages,
"generationConfig": {"temperature": temperature},
"tools": tools,
}
if stream:
async def generate():
retries = 0
MAX_RETRIES = 3
current_api_key = api_key
while retries < MAX_RETRIES:
try:
async with httpx.AsyncClient() as client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
async with client.stream(
"POST", stream_url, json=payload
) as response:
if response.status_code != 200:
if retries < MAX_RETRIES - 1:
logger.warning(
f"API error: {response.status_code}, attempting retry {retries + 1}"
)
current_api_key = (
await self.key_manager.handle_api_failure(
current_api_key
)
)
logger.info(
f"Switched to new API key: {current_api_key}"
)
retries += 1
continue
else:
logger.error(
f"Max retries reached. Final error: {response.status_code}"
)
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
return
async for line in response.aiter_lines():
if line.startswith("data: "):
try:
chunk = json.loads(line[6:])
openai_chunk = (
self.convert_gemini_response_to_openai(
chunk, model, stream=True
)
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
return # 成功完成,退出重试循环
except Exception as e:
if retries < MAX_RETRIES - 1:
logger.warning(
f"Stream error: {str(e)}, attempting retry {retries + 1}"
)
current_api_key = await self.key_manager.handle_api_failure(
current_api_key
)
retries += 1
continue
else:
logger.error(f"Max retries reached. Final error: {str(e)}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return
return generate()
else:
async with httpx.AsyncClient() as client:
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
gemini_response = response.json()
return self.convert_gemini_response_to_openai(gemini_response, model)
async def _openai_chat_completion(
self,
messages: list,
model: str,
temperature: float,
stream: bool,
api_key: str,
tools: Optional[list] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Handle OpenAI API chat completion"""
client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
if tools:
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
stream=stream,
tools=tools,
)
else:
response = client.chat.completions.create(
model=model, messages=messages, temperature=temperature, stream=stream
)
if stream:
async def generate():
for chunk in response:
yield f"data: {chunk.model_dump_json()}\n\n"
return generate()
return response
def format_code_block(self, code_data: dict) -> str:
"""格式化代码块输出"""
language = code_data.get("language", "").lower()
code = code_data.get("code", "").strip()
return f"""\n```{language}\n{code}\n```\n"""
def format_execution_result(result_data: dict) -> str:
"""格式化执行结果输出"""
outcome = result_data.get("outcome", "")
output = result_data.get("output", "").strip()
return f"""\n【执行结果】\n{output}\n"""

View File

@@ -0,0 +1,22 @@
import logging
import openai
from typing import Union, List, Dict, Any
logger = logging.getLogger(__name__)
class EmbeddingService:
def __init__(self, base_url: str):
self.base_url = base_url
async def create_embedding(
self, input_text: Union[str, List[str]], model: str, api_key: str
) -> Dict[str, Any]:
"""Create embeddings using OpenAI API"""
try:
client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
response = client.embeddings.create(input=input_text, model=model)
return response
except Exception as e:
logger.error(f"Error creating embedding: {str(e)}")
raise

View File

@@ -0,0 +1,57 @@
import asyncio
from itertools import cycle
import logging
from typing import Dict
logger = logging.getLogger(__name__)
class KeyManager:
def __init__(self, api_keys: list):
self.api_keys = api_keys
self.key_cycle = cycle(api_keys)
self.key_cycle_lock = asyncio.Lock()
self.failure_count_lock = asyncio.Lock()
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
self.MAX_FAILURES = 10
async def get_next_key(self) -> str:
"""获取下一个API key"""
async with self.key_cycle_lock:
return next(self.key_cycle)
async def is_key_valid(self, key: str) -> bool:
"""检查key是否有效"""
async with self.failure_count_lock:
return self.key_failure_counts[key] < self.MAX_FAILURES
async def reset_failure_counts(self):
"""重置所有key的失败计数"""
async with self.failure_count_lock:
for key in self.key_failure_counts:
self.key_failure_counts[key] = 0
async def get_next_working_key(self) -> str:
"""获取下一个可用的API key"""
initial_key = await self.get_next_key()
current_key = initial_key
while True:
if await self.is_key_valid(current_key):
return current_key
current_key = await self.get_next_key()
if current_key == initial_key:
await self.reset_failure_counts()
return current_key
async def handle_api_failure(self, api_key: str) -> str:
"""处理API调用失败"""
async with self.failure_count_lock:
self.key_failure_counts[api_key] += 1
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
logger.warning(
f"API key {api_key} has failed {self.MAX_FAILURES} times"
)
return await self.get_next_working_key()

View File

@@ -0,0 +1,55 @@
import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
import logging
logger = logging.getLogger(__name__)
class ModelService:
def __init__(self, model_search: list):
self.model_search = model_search
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
base_url = "https://generativelanguage.googleapis.com/v1beta"
url = f"{base_url}/models?key={api_key}"
try:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
return self.convert_to_openai_models_format(gemini_models)
else:
logger.error(f"Error: {response.status_code}")
logger.error(response.text)
return None
except requests.RequestException as e:
logger.error(f"Request failed: {e}")
return None
def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": []}
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
openai_model = {
"id": model_id,
"object": "model",
"created": int(datetime.now(timezone.utc).timestamp()),
"owned_by": "google",
"permission": [],
"root": model["name"],
"parent": None,
"success": True,
}
openai_format["data"].append(openai_model)
if model_id in self.model_search:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
return openai_format

0
app/utils/helpers.py Normal file
View File

416
main.py
View File

@@ -1,416 +0,0 @@
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
from itertools import cycle
import asyncio
import uvicorn
from app import config
import requests
from datetime import datetime, timezone
import json
import httpx
import uuid
import time
# 配置日志
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
# 允许跨域
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API密钥配置
API_KEYS = config.settings.API_KEYS
# 创建一个循环迭代器
key_cycle = cycle(API_KEYS)
# 创建两个独立的锁
key_cycle_lock = asyncio.Lock()
failure_count_lock = asyncio.Lock()
# 添加key失败计数记录
key_failure_counts = {key: 0 for key in API_KEYS}
MAX_FAILURES = 10 # 最大失败次数阈值
MAX_RETRIES = 3 # 最大重试次数
async def get_next_key():
"""仅获取下一个key,不检查失败次数"""
async with key_cycle_lock:
return next(key_cycle)
async def is_key_valid(key):
"""检查key是否有效"""
async with failure_count_lock:
return key_failure_counts[key] < MAX_FAILURES
async def reset_failure_counts():
"""重置所有key的失败计数"""
async with failure_count_lock:
for key in key_failure_counts:
key_failure_counts[key] = 0
async def get_next_working_key():
"""获取下一个可用的API key"""
initial_key = await get_next_key()
current_key = initial_key
while True:
if await is_key_valid(current_key):
return current_key
current_key = await get_next_key()
if current_key == initial_key: # 已经循环了一圈
await reset_failure_counts()
return current_key
async def handle_api_failure(api_key):
"""处理API调用失败"""
async with failure_count_lock:
key_failure_counts[api_key] += 1
if key_failure_counts[api_key] >= MAX_FAILURES:
logger.warning(
f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key"
)
# 在锁外获取新的key
return await get_next_working_key()
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
tool_choice: Optional[str] = "auto"
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-004"
encoding_format: Optional[str] = "float"
async def verify_authorization(authorization: str = Header(None)):
if not authorization:
logger.error("Missing Authorization header")
raise HTTPException(status_code=401, detail="Missing Authorization header")
if not authorization.startswith("Bearer "):
logger.error("Invalid Authorization header format")
raise HTTPException(
status_code=401, detail="Invalid Authorization header format"
)
token = authorization.replace("Bearer ", "")
if token not in config.settings.ALLOWED_TOKENS:
logger.error("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
return token
def get_gemini_models(api_key):
base_url = "https://generativelanguage.googleapis.com/v1beta"
url = f"{base_url}/models?key={api_key}"
try:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
return convert_to_openai_models_format(gemini_models)
else:
print(f"Error: {response.status_code}")
print(response.text)
return None
except requests.RequestException as e:
print(f"Request failed: {e}")
return None
def convert_to_openai_models_format(gemini_models):
openai_format = {"object": "list", "data": []}
# 添加常规模型
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
openai_model = {
"id": model_id,
"object": "model",
"created": int(datetime.now(timezone.utc).timestamp()), # 使用当前时间戳
"owned_by": "google", # 假设所有Gemini模型都由Google拥有
"permission": [], # Gemini API可能没有直接对应的权限信息
"root": model["name"],
"parent": None, # Gemini API可能没有直接对应的父模型信息
}
openai_format["data"].append(openai_model)
# 如果模型在 MODEL_SEARCH 中,添加带 search 后缀的版本
if model_id in config.settings.MODEL_SEARCH:
search_model = openai_model.copy()
search_model["id"] = f"{model_id}-search"
openai_format["data"].append(search_model)
return openai_format
def convert_messages_to_gemini_format(messages):
"""Convert OpenAI message format to Gemini format"""
gemini_messages = []
for message in messages:
gemini_message = {
"role": "user" if message["role"] == "user" else "model",
"parts": [{"text": message["content"]}],
}
gemini_messages.append(gemini_message)
return gemini_messages
def convert_gemini_response_to_openai(response, model, stream=False):
"""Convert Gemini response to OpenAI format"""
if stream:
# 处理流式响应
chunk = response
if not chunk["candidates"]:
return None
return {
"id": "chatcmpl-" + str(uuid.uuid4()),
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": chunk["candidates"][0]["content"]["parts"][0]["text"]
},
"finish_reason": None,
}
],
}
else:
# 处理普通响应
return {
"id": "chatcmpl-" + str(uuid.uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response["candidates"][0]["content"]["parts"][0][
"text"
],
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
response = get_gemini_models(api_key)
logger.info("Successfully retrieved models list")
return response
except Exception as e:
logger.error(f"Error listing models: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Chat completion request - Model: {request.model}")
retries = 0
while retries < MAX_RETRIES:
try:
logger.info(f"Attempt {retries + 1} with API key: {api_key}")
# 修改判断条件,检查模型名是否以 -search 结尾
if request.model.endswith("-search"):
# 去掉 -search 后缀
gemini_model = request.model[:-7]
# Gemini API调用部分
gemini_messages = convert_messages_to_gemini_format(request.messages)
# 调用Gemini API
payload = {
"contents": gemini_messages,
"generationConfig": {
"temperature": request.temperature,
},
"tools": [{"googleSearch": {}}],
}
if request.stream:
logger.info("Streaming response enabled")
async def generate():
nonlocal api_key, retries
while retries < MAX_RETRIES:
try:
async with httpx.AsyncClient() as client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(
"POST", stream_url, json=payload
) as response:
if response.status_code == 429:
logger.warning(
f"Rate limit reached for key: {api_key}"
)
api_key = await handle_api_failure(api_key)
logger.info(
f"Retrying with new API key: {api_key}"
)
retries += 1
if retries >= MAX_RETRIES:
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
break
continue
if response.status_code != 200:
logger.error(
f"Error in streaming response: {response.status_code}"
)
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
break
async for line in response.aiter_lines():
if line.startswith("data: "):
try:
chunk = json.loads(line[6:])
openai_chunk = convert_gemini_response_to_openai(
chunk,
request.model,
stream=True,
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
return
except Exception as e:
logger.error(f"Stream error: {str(e)}")
api_key = await handle_api_failure(api_key)
retries += 1
if retries >= MAX_RETRIES:
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
break
continue
return StreamingResponse(
content=generate(), media_type="text/event-stream"
)
else:
# 非流式响应
async with httpx.AsyncClient() as client:
non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
response = await client.post(non_stream_url, json=payload)
gemini_response = response.json()
logger.info("Chat completion successful")
return convert_gemini_response_to_openai(
gemini_response, request.model
)
# OpenAI API调用部分
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
response = client.chat.completions.create(
model=request.model,
messages=request.messages,
temperature=request.temperature,
stream=request.stream if hasattr(request, "stream") else False,
)
if hasattr(request, "stream") and request.stream:
logger.info("Streaming response enabled")
async def generate():
for chunk in response:
yield f"data: {chunk.model_dump_json()}\n\n"
logger.info("Chat completion successful")
return StreamingResponse(
content=generate(), media_type="text/event-stream"
)
logger.info("Chat completion successful")
return response
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}")
api_key = await handle_api_failure(api_key)
retries += 1
if retries >= MAX_RETRIES:
logger.error("Max retries reached, giving up")
raise HTTPException(
status_code=500,
detail="Max retries reached with all available API keys",
)
logger.info(f"Retrying with new API key: {api_key}")
continue
raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
response = client.embeddings.create(input=request.input, model=request.model)
logger.info("Embedding successful")
return response
except Exception as e:
logger.error(f"Error in embedding: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/")
async def health_check():
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)