mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-01 05:39:47 +08:00
添加API密钥管理、模型服务和安全服务,并优化FastAPI应用程序配置
This commit is contained in:
8
.vscode/launch.json
vendored
8
.vscode/launch.json
vendored
@@ -10,8 +10,12 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"main:app",
|
||||
"--reload"
|
||||
"app.main:app",
|
||||
"--reload",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"8000"
|
||||
],
|
||||
"jinja": true
|
||||
}
|
||||
|
||||
0
app/api/dependencies.py
Normal file
0
app/api/dependencies.py
Normal file
98
app/api/routes.py
Normal file
98
app/api/routes.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from typing import Optional
|
||||
import logging
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.security import SecurityService
|
||||
from app.services.key_manager import KeyManager
|
||||
from app.services.model_service import ModelService
|
||||
from app.services.chat_service import ChatService
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
from app.schemas.request_model import ChatRequest, EmbeddingRequest
|
||||
from app.core.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
chat_service = ChatService(settings.BASE_URL, key_manager)
|
||||
embedding_service = EmbeddingService(settings.BASE_URL)
|
||||
|
||||
|
||||
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
    authorization: str = Header(None),
    token: str = Depends(security_service.verify_authorization),
):
    """Return the upstream model list in OpenAI ``/v1/models`` format.

    The handler is registered under both ``/v1/models`` and ``/hf/v1/models``.
    Authentication is performed by the ``verify_authorization`` dependency;
    ``authorization`` and ``token`` are not otherwise used here.

    Returns:
        Whatever ``model_service.get_gemini_models`` returns for the selected
        key (a dict on success, ``None`` when the upstream call fails).
    """
    import asyncio  # local import: this module does not import asyncio at the top

    logger.info("Handling models list request")
    api_key = await key_manager.get_next_working_key()
    # Never write a full credential to the log; the last four characters are
    # enough to tell keys apart.
    logger.info(f"Using API key: ...{api_key[-4:]}")

    # get_gemini_models() performs a blocking HTTP request, so run it in the
    # default executor instead of stalling the event loop for other requests.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, model_service.get_gemini_models, api_key
    )
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
async def chat_completion(
    request: ChatRequest,
    authorization: str = Header(None),
    token: str = Depends(security_service.verify_authorization),
):
    """Handle an OpenAI-style chat completion request with key rotation.

    The handler is registered under both ``/v1/chat/completions`` and
    ``/hf/v1/chat/completions``. Up to ``MAX_RETRIES`` attempts are made; after
    each failure the key is reported to ``key_manager.handle_api_failure`` and
    the key it returns is used for the next attempt.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``request.stream``
        is truthy, otherwise the value returned by the chat service.

    Raises:
        Exception: the error from the final attempt is re-raised unchanged.
    """
    logger.info(f"Handling chat completion request for model: {request.model}")
    api_key = await key_manager.get_next_working_key()
    # Log only a short suffix of the key, never the credential itself.
    logger.info(f"Using API key: ...{api_key[-4:]}")
    MAX_RETRIES = 3

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = await chat_service.create_chat_completion(
                messages=request.messages,
                model=request.model,
                temperature=request.temperature,
                stream=request.stream,
                api_key=api_key,
                # Hand over a fresh copy on every attempt so tool entries
                # appended downstream cannot accumulate across retries.
                tools=list(request.tools or []),
                tool_choice=request.tool_choice,
            )
        except Exception as e:
            logger.warning(
                f"API call failed with error: {str(e)}. Attempt {attempt} of {MAX_RETRIES}"
            )
            # Always record the failure, even on the last attempt.
            api_key = await key_manager.handle_api_failure(api_key)
            if attempt >= MAX_RETRIES:
                logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
                raise
            logger.info(f"Switched to new API key: ...{api_key[-4:]}")
        else:
            # Streaming results are async generators of SSE lines.
            if request.stream:
                return StreamingResponse(response, media_type="text/event-stream")
            return response
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    authorization: str = Header(None),
    token: str = Depends(security_service.verify_authorization),
):
    """Create embeddings for ``request.input`` using the next working API key.

    The handler is registered under both ``/v1/embeddings`` and
    ``/hf/v1/embeddings``. Authentication is performed by the
    ``verify_authorization`` dependency.

    Returns:
        The value returned by ``embedding_service.create_embedding``.

    Raises:
        Exception: any error from the embedding service is logged and re-raised.
    """
    logger.info(f"Handling embedding request for model: {request.model}")
    api_key = await key_manager.get_next_working_key()
    # Log only a short suffix of the key, never the credential itself.
    logger.info(f"Using API key: ...{api_key[-4:]}")
    try:
        response = await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
    except Exception as e:
        logger.error(f"Embedding request failed: {str(e)}")
        raise
    logger.info("Embedding request successful")
    return response
|
||||
@@ -1,20 +0,0 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
class Settings(BaseSettings):
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
BASE_URL: str
|
||||
MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = True
|
||||
# 同时从环境变量和.env文件获取配置
|
||||
env_nested_delimiter = "__"
|
||||
extra = "ignore"
|
||||
|
||||
# 优先从环境变量获取,如果没有则从.env文件获取
|
||||
settings = Settings(_env_file=os.getenv("ENV_FILE", ".env"))
|
||||
16
app/core/config.py
Normal file
16
app/core/config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import List
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings, read by pydantic-settings with ``.env`` as env file."""

    # Upstream API keys; required (no default).
    API_KEYS: List[str]
    # Bearer tokens accepted from clients; required (no default).
    ALLOWED_TOKENS: List[str]
    # Base URL handed to the chat and embedding services.
    BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
    # Model ids that additionally get a "-search" variant in the model list.
    MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"]
    # When true, the chat service adds the code-execution tool to requests.
    TOOLS_CODE_EXECUTION_ENABLED: bool = False

    class Config:
        env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
30
app/core/security.py
Normal file
30
app/core/security.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from fastapi import HTTPException, Header
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityService:
    """Check ``Authorization: Bearer <token>`` headers against an allow-list."""

    def __init__(self, allowed_tokens: list):
        """Store the collection of tokens that are allowed to call the API."""
        self.allowed_tokens = allowed_tokens

    async def verify_authorization(
        self, authorization: Optional[str] = Header(None)
    ) -> str:
        """Validate the ``Authorization`` header and return the bearer token.

        Args:
            authorization: Raw header value, injected by FastAPI.

        Returns:
            The token that follows the ``"Bearer "`` prefix.

        Raises:
            HTTPException: 401 when the header is missing, does not start with
                ``"Bearer "``, or carries a token that is not in the allow-list.
        """
        if not authorization:
            logger.error("Missing Authorization header")
            raise HTTPException(status_code=401, detail="Missing Authorization header")

        scheme = "Bearer "
        if not authorization.startswith(scheme):
            logger.error("Invalid Authorization header format")
            raise HTTPException(
                status_code=401, detail="Invalid Authorization header format"
            )

        # Strip only the leading scheme. str.replace() would also delete every
        # later "Bearer " occurrence and so validate a different string than
        # the one the client actually sent.
        token = authorization[len(scheme):]
        if token not in self.allowed_tokens:
            logger.error("Invalid token")
            raise HTTPException(status_code=401, detail="Invalid token")

        return token
|
||||
38
app/main.py
Normal file
38
app/main.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import logging
|
||||
|
||||
from app.api.routes import router
|
||||
import uvicorn
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 允许跨域
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 包含所有路由
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.get("/health")
@app.get("/")
async def health_check():
    """Liveness probe served on both ``/health`` and ``/``."""
    logger.info("Health check endpoint called")
    status_payload = {"status": "healthy"}
    return status_payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
17
app/schemas/request_model.py
Normal file
17
app/schemas/request_model.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for the chat completion endpoints."""

    # Conversation history; each entry is a dict with "role" and "content".
    messages: List[dict]
    model: str = "gemini-1.5-flash-002"
    temperature: Optional[float] = 0.7
    # When true the endpoint answers with a server-sent event stream.
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = []
    # OpenAI clients may send either a string ("auto", "none", ...) or an
    # object naming a specific function; accept both so the object form is
    # not rejected during request validation.
    tool_choice: Optional[Union[str, dict]] = "auto"
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoints."""

    # A single text or a batch of texts to embed.
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"
|
||||
0
app/schemas/response_model.py
Normal file
0
app/schemas/response_model.py
Normal file
299
app/services/chat_service.py
Normal file
299
app/services/chat_service.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, AsyncGenerator, Union
|
||||
import openai
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatService:
    """Bridge OpenAI-style chat requests to the Gemini REST API and back.

    Messages are converted to Gemini ``contents``, sent to the
    ``generateContent`` / ``streamGenerateContent`` endpoints, and the answers
    are reshaped into OpenAI ``chat.completion`` objects or SSE chunks.
    """

    # Number of attempts the streaming path makes before reporting an error.
    STREAM_MAX_RETRIES = 3

    def __init__(self, base_url: str, key_manager=None):
        """Store the base URL and the optional key manager.

        Args:
            base_url: Base URL used by the OpenAI-compatible client path.
            key_manager: Object exposing ``handle_api_failure(key)``; when it is
                ``None`` a failing key is retried instead of rotated.
        """
        self.base_url = base_url
        self.key_manager = key_manager

    @staticmethod
    def _image_part(image_url: str) -> dict:
        """Build a Gemini ``inline_data`` part from an OpenAI ``image_url``."""
        mime_type = "image/jpeg"
        data = image_url
        if image_url.startswith("data:image"):
            # data:<mime>[;base64],<payload> -- keep the declared MIME type
            # instead of labelling every image as JPEG.
            header, _, data = image_url.partition(",")
            declared = header[len("data:"):].split(";", 1)[0]
            if declared:
                mime_type = declared
        # NOTE(review): a plain http(s) URL is forwarded unchanged as the
        # inline data, exactly as before; confirm the upstream API accepts it.
        return {"inline_data": {"mime_type": mime_type, "data": data}}

    def convert_messages_to_gemini_format(self, messages: list) -> list:
        """Convert OpenAI chat messages to Gemini ``contents`` entries.

        ``"user"`` keeps its role; every other role is mapped to ``"model"``.
        String content becomes one text part. List content is converted item
        by item (plain strings, ``text`` items and ``image_url`` items); other
        content types yield no parts.
        """
        converted_messages = []
        for msg in messages:
            role = "user" if msg["role"] == "user" else "model"
            content = msg["content"]
            parts = []

            if isinstance(content, str):
                parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        parts.append({"text": item})
                    elif isinstance(item, dict) and item["type"] == "text":
                        parts.append({"text": item["text"]})
                    elif isinstance(item, dict) and item["type"] == "image_url":
                        parts.append(self._image_part(item["image_url"]["url"]))

            converted_messages.append({"role": role, "parts": parts})

        return converted_messages

    def _part_to_text(self, part: dict) -> str:
        """Render one Gemini response part as text ('' for unknown parts)."""
        if "text" in part:
            return part.get("text") or ""
        if "executableCode" in part:
            return self.format_code_block(part["executableCode"])
        # "codeExecutionResult" is the REST field name; the older spelling is
        # still accepted so existing behaviour is a subset of the new one.
        for key in ("codeExecutionResult", "executableCodeResult"):
            if key in part:
                return self.format_execution_result(part[key])
        return ""

    def convert_gemini_response_to_openai(
        self, response: Dict[str, Any], model: str, stream: bool = False
    ) -> Optional[Dict[str, Any]]:
        """Convert a Gemini response (or stream chunk) to OpenAI format.

        Args:
            response: Parsed JSON from Gemini.
            model: Model name echoed back in the result.
            stream: When true, build a ``chat.completion.chunk``.

        Returns:
            The OpenAI-style dict. In stream mode ``None`` is returned for
            chunks without candidates or parts, or when conversion fails.

        Raises:
            KeyError, IndexError: in non-stream mode when the response has no
                usable candidate (callers treat this as a failed attempt).
        """
        if not stream:
            parts = response["candidates"][0]["content"]["parts"]
            # Join every part so multi-part answers (text + code + result)
            # are not truncated to the first part.
            content = "".join(self._part_to_text(part) for part in parts)
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": content},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }

        if not response.get("candidates"):
            return None

        try:
            candidate = response["candidates"][0]
            parts = candidate.get("content", {}).get("parts", [])
            if not parts:
                return None

            text = "".join(self._part_to_text(part) for part in parts)
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": text},
                        "finish_reason": None,
                    }
                ],
            }
        except Exception as e:
            logger.error(f"Error converting Gemini response: {str(e)}")
            logger.debug(f"Raw response: {response}")
            return None

    async def create_chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        stream: bool,
        api_key: str,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion through the Gemini API.

        Adds the code-execution tool when
        ``settings.TOOLS_CODE_EXECUTION_ENABLED`` is set and the Google Search
        tool when ``model`` ends with ``-search``. ``tool_choice`` is accepted
        for interface compatibility and is not forwarded.

        Returns:
            An OpenAI-style dict, or an async generator of SSE lines when
            ``stream`` is true.
        """
        # Work on a copy: the caller's list belongs to the request object and
        # is reused across retries, so appending in place would pile up
        # duplicate tool entries.
        tools = list(tools) if tools else []
        if settings.TOOLS_CODE_EXECUTION_ENABLED:
            tools.append({"code_execution": {}})
        if model.endswith("-search"):
            tools.append({"googleSearch": {}})
        return await self._gemini_chat_completion(
            messages, model, temperature, stream, api_key, tools
        )

    async def _rotate_key(self, failed_key: str) -> str:
        """Report ``failed_key`` and return the key to use next."""
        if self.key_manager is None:
            # No manager configured: nothing to rotate to, retry the same key.
            return failed_key
        return await self.key_manager.handle_api_failure(failed_key)

    async def _stream_gemini(
        self, gemini_model: str, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Yield OpenAI-style SSE lines for one streamed Gemini completion.

        Failed attempts are retried with a rotated key, up to
        ``STREAM_MAX_RETRIES`` attempts in total. Errors are reported to the
        client as a ``data: {"error": ...}`` event rather than raised.
        """
        current_api_key = api_key
        sent_any = False

        for attempt in range(self.STREAM_MAX_RETRIES):
            can_retry = attempt < self.STREAM_MAX_RETRIES - 1
            try:
                async with httpx.AsyncClient() as client:
                    stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
                    async with client.stream(
                        "POST", stream_url, json=payload
                    ) as response:
                        if response.status_code != 200:
                            if not can_retry:
                                logger.error(
                                    f"Max retries reached. Final error: {response.status_code}"
                                )
                                error_event = json.dumps(
                                    {"error": f"API error: {response.status_code}"}
                                )
                                yield f"data: {error_event}\n\n"
                                return
                            logger.warning(
                                f"API error: {response.status_code}, attempting retry {attempt + 1}"
                            )
                            current_api_key = await self._rotate_key(current_api_key)
                            # Log only a short suffix, never the credential.
                            logger.info(
                                f"Switched to new API key: ...{current_api_key[-4:]}"
                            )
                            continue

                        async for line in response.aiter_lines():
                            if not line.startswith("data: "):
                                continue
                            try:
                                chunk = json.loads(line[6:])
                            except json.JSONDecodeError:
                                continue
                            openai_chunk = self.convert_gemini_response_to_openai(
                                chunk, model, stream=True
                            )
                            if openai_chunk:
                                sent_any = True
                                yield f"data: {json.dumps(openai_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                        return  # finished successfully, leave the retry loop

            except Exception as e:
                if sent_any:
                    # Part of the answer already reached the client; starting
                    # over would replay it, so report the error instead.
                    logger.error(f"Stream failed after partial output: {str(e)}")
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                    return
                if not can_retry:
                    logger.error(f"Max retries reached. Final error: {str(e)}")
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                    return
                logger.warning(
                    f"Stream error: {str(e)}, attempting retry {attempt + 1}"
                )
                current_api_key = await self._rotate_key(current_api_key)

    async def _gemini_chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        stream: bool,
        api_key: str,
        tools: Optional[list] = None,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Send the request to Gemini and return the converted result.

        A trailing ``-search`` is stripped from ``model`` to obtain the
        upstream model name; the original name is echoed in the response.

        Raises:
            RuntimeError: non-stream mode, when Gemini answers with a status
                other than 200.
        """
        gemini_model = model[:-7] if model.endswith("-search") else model
        payload = {
            "contents": self.convert_messages_to_gemini_format(messages),
            "generationConfig": {"temperature": temperature},
            "tools": tools,
        }

        if stream:
            return self._stream_gemini(gemini_model, model, payload, api_key)

        async with httpx.AsyncClient() as client:
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
            response = await client.post(url, json=payload)
            if response.status_code != 200:
                # Fail with an explicit message (without the key-bearing URL)
                # instead of an opaque KeyError during conversion.
                raise RuntimeError(f"Gemini API error: {response.status_code}")
            return self.convert_gemini_response_to_openai(response.json(), model)

    async def _openai_chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        stream: bool,
        api_key: str,
        tools: Optional[list] = None,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion through an OpenAI-compatible endpoint.

        Not called by ``create_chat_completion``. The client used here is
        synchronous, so the request runs on the calling thread.
        """
        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
        request_kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "stream": stream,
        }
        if tools:
            request_kwargs["tools"] = tools
        response = client.chat.completions.create(**request_kwargs)

        if not stream:
            return response

        async def generate():
            for chunk in response:
                yield f"data: {chunk.model_dump_json()}\n\n"

        return generate()

    def format_code_block(self, code_data: dict) -> str:
        """Format an executable-code part as a fenced Markdown code block."""
        language = code_data.get("language", "").lower()
        code = code_data.get("code", "").strip()
        return f"""\n```{language}\n{code}\n```\n"""

    @staticmethod
    def format_execution_result(result_data: dict) -> str:
        """Format a code-execution result part as text.

        Declared as a static method so that it works both through an instance
        (``self.format_execution_result(data)``) and through the class.
        """
        output = result_data.get("output", "").strip()
        return f"""\n【执行结果】\n{output}\n"""
|
||||
22
app/services/embedding_service.py
Normal file
22
app/services/embedding_service.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import logging
|
||||
import openai
|
||||
from typing import Union, List, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingService:
    """Create embeddings through an OpenAI-compatible endpoint."""

    def __init__(self, base_url: str):
        """Remember the base URL of the OpenAI-compatible API."""
        self.base_url = base_url

    async def create_embedding(
        self, input_text: Union[str, List[str]], model: str, api_key: str
    ) -> Dict[str, Any]:
        """Request embeddings for ``input_text``.

        Args:
            input_text: One text or a list of texts to embed.
            model: Embedding model name.
            api_key: Upstream API key used for this request.

        Returns:
            The response object produced by the OpenAI client.

        Raises:
            Exception: any client error is logged and re-raised.
        """
        try:
            # Use the asynchronous client so the HTTP round trip is awaited
            # instead of blocking the event loop inside this coroutine.
            client = openai.AsyncOpenAI(api_key=api_key, base_url=self.base_url)
            return await client.embeddings.create(input=input_text, model=model)
        except Exception as e:
            logger.error(f"Error creating embedding: {str(e)}")
            raise
|
||||
57
app/services/key_manager.py
Normal file
57
app/services/key_manager.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
def __init__(self, api_keys: list):
|
||||
self.api_keys = api_keys
|
||||
self.key_cycle = cycle(api_keys)
|
||||
self.key_cycle_lock = asyncio.Lock()
|
||||
self.failure_count_lock = asyncio.Lock()
|
||||
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
||||
self.MAX_FAILURES = 10
|
||||
|
||||
async def get_next_key(self) -> str:
|
||||
"""获取下一个API key"""
|
||||
async with self.key_cycle_lock:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def is_key_valid(self, key: str) -> bool:
|
||||
"""检查key是否有效"""
|
||||
async with self.failure_count_lock:
|
||||
return self.key_failure_counts[key] < self.MAX_FAILURES
|
||||
|
||||
async def reset_failure_counts(self):
|
||||
"""重置所有key的失败计数"""
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
|
||||
async def get_next_working_key(self) -> str:
|
||||
"""获取下一个可用的API key"""
|
||||
initial_key = await self.get_next_key()
|
||||
current_key = initial_key
|
||||
|
||||
while True:
|
||||
if await self.is_key_valid(current_key):
|
||||
return current_key
|
||||
|
||||
current_key = await self.get_next_key()
|
||||
if current_key == initial_key:
|
||||
await self.reset_failure_counts()
|
||||
return current_key
|
||||
|
||||
async def handle_api_failure(self, api_key: str) -> str:
|
||||
"""处理API调用失败"""
|
||||
async with self.failure_count_lock:
|
||||
self.key_failure_counts[api_key] += 1
|
||||
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
|
||||
return await self.get_next_working_key()
|
||||
55
app/services/model_service.py
Normal file
55
app/services/model_service.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelService:
    """Fetch the Gemini model list and present it in OpenAI format."""

    def __init__(self, model_search: list, timeout: float = 10.0):
        """Configure the service.

        Args:
            model_search: Model ids that also get a ``-search`` variant.
            timeout: Seconds to wait for the upstream model list request.
        """
        self.model_search = model_search
        self.timeout = timeout

    def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the model list from Gemini and convert it to OpenAI format.

        Returns:
            The converted model list, or ``None`` when the request fails, the
            status is not 200, or the body is not valid JSON.
        """
        base_url = "https://generativelanguage.googleapis.com/v1beta"
        url = f"{base_url}/models?key={api_key}"

        try:
            # Without a timeout a stalled connection would block forever.
            response = requests.get(url, timeout=self.timeout)
        except requests.RequestException as e:
            # Exception text can embed the request URL, which carries the key.
            message = str(e).replace(api_key, "***") if api_key else str(e)
            logger.error(f"Request failed: {message}")
            return None

        if response.status_code != 200:
            logger.error(f"Error: {response.status_code}")
            logger.error(response.text)
            return None

        try:
            gemini_models = response.json()
        except ValueError as e:
            logger.error(f"Request failed: invalid JSON in response ({e})")
            return None

        return self.convert_to_openai_models_format(gemini_models)

    def convert_to_openai_models_format(
        self, gemini_models: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert a Gemini ``models`` listing to an OpenAI model list.

        Every model whose id is in ``self.model_search`` is followed by a copy
        whose id has the ``-search`` suffix.
        """
        # One timestamp for the whole listing; it does not vary per model.
        created = int(datetime.now(timezone.utc).timestamp())
        data = []

        for model in gemini_models.get("models", []):
            model_id = model["name"].split("/")[-1]
            openai_model = {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": "google",
                "permission": [],
                "root": model["name"],
                "parent": None,
                "success": True,
            }
            data.append(openai_model)

            if model_id in self.model_search:
                data.append({**openai_model, "id": f"{model_id}-search"})

        return {"object": "list", "data": data}
|
||||
0
app/utils/helpers.py
Normal file
0
app/utils/helpers.py
Normal file
416
main.py
416
main.py
@@ -1,416 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import openai
|
||||
from typing import List, Optional, Union
|
||||
import logging
|
||||
from itertools import cycle
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
|
||||
from app import config
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import httpx
|
||||
import uuid
|
||||
import time
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 允许跨域
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# API密钥配置
|
||||
API_KEYS = config.settings.API_KEYS
|
||||
|
||||
# 创建一个循环迭代器
|
||||
key_cycle = cycle(API_KEYS)
|
||||
|
||||
# 创建两个独立的锁
|
||||
key_cycle_lock = asyncio.Lock()
|
||||
failure_count_lock = asyncio.Lock()
|
||||
|
||||
# 添加key失败计数记录
|
||||
key_failure_counts = {key: 0 for key in API_KEYS}
|
||||
MAX_FAILURES = 10 # 最大失败次数阈值
|
||||
MAX_RETRIES = 3 # 最大重试次数
|
||||
|
||||
|
||||
async def get_next_key():
|
||||
"""仅获取下一个key,不检查失败次数"""
|
||||
async with key_cycle_lock:
|
||||
return next(key_cycle)
|
||||
|
||||
|
||||
async def is_key_valid(key):
|
||||
"""检查key是否有效"""
|
||||
async with failure_count_lock:
|
||||
return key_failure_counts[key] < MAX_FAILURES
|
||||
|
||||
|
||||
async def reset_failure_counts():
|
||||
"""重置所有key的失败计数"""
|
||||
async with failure_count_lock:
|
||||
for key in key_failure_counts:
|
||||
key_failure_counts[key] = 0
|
||||
|
||||
|
||||
async def get_next_working_key():
|
||||
"""获取下一个可用的API key"""
|
||||
initial_key = await get_next_key()
|
||||
current_key = initial_key
|
||||
|
||||
while True:
|
||||
if await is_key_valid(current_key):
|
||||
return current_key
|
||||
|
||||
current_key = await get_next_key()
|
||||
if current_key == initial_key: # 已经循环了一圈
|
||||
await reset_failure_counts()
|
||||
return current_key
|
||||
|
||||
|
||||
async def handle_api_failure(api_key):
|
||||
"""处理API调用失败"""
|
||||
async with failure_count_lock:
|
||||
key_failure_counts[api_key] += 1
|
||||
if key_failure_counts[api_key] >= MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key"
|
||||
)
|
||||
|
||||
# 在锁外获取新的key
|
||||
return await get_next_working_key()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[dict]
|
||||
model: str = "gemini-1.5-flash-002"
|
||||
temperature: Optional[float] = 0.7
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[List[dict]] = []
|
||||
tool_choice: Optional[str] = "auto"
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: Union[str, List[str]]
|
||||
model: str = "text-embedding-004"
|
||||
encoding_format: Optional[str] = "float"
|
||||
|
||||
|
||||
async def verify_authorization(authorization: str = Header(None)):
|
||||
if not authorization:
|
||||
logger.error("Missing Authorization header")
|
||||
raise HTTPException(status_code=401, detail="Missing Authorization header")
|
||||
if not authorization.startswith("Bearer "):
|
||||
logger.error("Invalid Authorization header format")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid Authorization header format"
|
||||
)
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token not in config.settings.ALLOWED_TOKENS:
|
||||
logger.error("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
return token
|
||||
|
||||
|
||||
def get_gemini_models(api_key):
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
url = f"{base_url}/models?key={api_key}"
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
return convert_to_openai_models_format(gemini_models)
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
return None
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f"Request failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def convert_to_openai_models_format(gemini_models):
|
||||
openai_format = {"object": "list", "data": []}
|
||||
|
||||
# 添加常规模型
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
openai_model = {
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": int(datetime.now(timezone.utc).timestamp()), # 使用当前时间戳
|
||||
"owned_by": "google", # 假设所有Gemini模型都由Google拥有
|
||||
"permission": [], # Gemini API可能没有直接对应的权限信息
|
||||
"root": model["name"],
|
||||
"parent": None, # Gemini API可能没有直接对应的父模型信息
|
||||
}
|
||||
openai_format["data"].append(openai_model)
|
||||
|
||||
# 如果模型在 MODEL_SEARCH 中,添加带 search 后缀的版本
|
||||
if model_id in config.settings.MODEL_SEARCH:
|
||||
search_model = openai_model.copy()
|
||||
search_model["id"] = f"{model_id}-search"
|
||||
openai_format["data"].append(search_model)
|
||||
|
||||
return openai_format
|
||||
|
||||
|
||||
def convert_messages_to_gemini_format(messages):
|
||||
"""Convert OpenAI message format to Gemini format"""
|
||||
gemini_messages = []
|
||||
for message in messages:
|
||||
gemini_message = {
|
||||
"role": "user" if message["role"] == "user" else "model",
|
||||
"parts": [{"text": message["content"]}],
|
||||
}
|
||||
gemini_messages.append(gemini_message)
|
||||
return gemini_messages
|
||||
|
||||
|
||||
def convert_gemini_response_to_openai(response, model, stream=False):
|
||||
"""Convert Gemini response to OpenAI format"""
|
||||
if stream:
|
||||
# 处理流式响应
|
||||
chunk = response
|
||||
if not chunk["candidates"]:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": "chatcmpl-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": chunk["candidates"][0]["content"]["parts"][0]["text"]
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
# 处理普通响应
|
||||
return {
|
||||
"id": "chatcmpl-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response["candidates"][0]["content"]["parts"][0][
|
||||
"text"
|
||||
],
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
@app.get("/hf/v1/models")
|
||||
async def list_models(authorization: str = Header(None)):
|
||||
await verify_authorization(authorization)
|
||||
api_key = await get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
try:
|
||||
response = get_gemini_models(api_key)
|
||||
logger.info("Successfully retrieved models list")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing models: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/hf/v1/chat/completions")
|
||||
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
|
||||
await verify_authorization(authorization)
|
||||
api_key = await get_next_working_key()
|
||||
logger.info(f"Chat completion request - Model: {request.model}")
|
||||
retries = 0
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
logger.info(f"Attempt {retries + 1} with API key: {api_key}")
|
||||
|
||||
# 修改判断条件,检查模型名是否以 -search 结尾
|
||||
if request.model.endswith("-search"):
|
||||
# 去掉 -search 后缀
|
||||
gemini_model = request.model[:-7]
|
||||
|
||||
# Gemini API调用部分
|
||||
gemini_messages = convert_messages_to_gemini_format(request.messages)
|
||||
# 调用Gemini API
|
||||
payload = {
|
||||
"contents": gemini_messages,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
},
|
||||
"tools": [{"googleSearch": {}}],
|
||||
}
|
||||
|
||||
if request.stream:
|
||||
logger.info("Streaming response enabled")
|
||||
|
||||
async def generate():
|
||||
nonlocal api_key, retries
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(
|
||||
"POST", stream_url, json=payload
|
||||
) as response:
|
||||
if response.status_code == 429:
|
||||
logger.warning(
|
||||
f"Rate limit reached for key: {api_key}"
|
||||
)
|
||||
api_key = await handle_api_failure(api_key)
|
||||
logger.info(
|
||||
f"Retrying with new API key: {api_key}"
|
||||
)
|
||||
retries += 1
|
||||
if retries >= MAX_RETRIES:
|
||||
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
|
||||
break
|
||||
continue
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Error in streaming response: {response.status_code}"
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
|
||||
break
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
chunk = json.loads(line[6:])
|
||||
openai_chunk = convert_gemini_response_to_openai(
|
||||
chunk,
|
||||
request.model,
|
||||
stream=True,
|
||||
)
|
||||
if openai_chunk:
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {str(e)}")
|
||||
api_key = await handle_api_failure(api_key)
|
||||
retries += 1
|
||||
if retries >= MAX_RETRIES:
|
||||
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
|
||||
break
|
||||
continue
|
||||
|
||||
return StreamingResponse(
|
||||
content=generate(), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# 非流式响应
|
||||
async with httpx.AsyncClient() as client:
|
||||
non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
|
||||
response = await client.post(non_stream_url, json=payload)
|
||||
gemini_response = response.json()
|
||||
logger.info("Chat completion successful")
|
||||
return convert_gemini_response_to_openai(
|
||||
gemini_response, request.model
|
||||
)
|
||||
|
||||
# OpenAI API调用部分
|
||||
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
|
||||
response = client.chat.completions.create(
|
||||
model=request.model,
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
stream=request.stream if hasattr(request, "stream") else False,
|
||||
)
|
||||
|
||||
if hasattr(request, "stream") and request.stream:
|
||||
logger.info("Streaming response enabled")
|
||||
|
||||
async def generate():
|
||||
for chunk in response:
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
logger.info("Chat completion successful")
|
||||
return StreamingResponse(
|
||||
content=generate(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
logger.info("Chat completion successful")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat completion: {str(e)}")
|
||||
api_key = await handle_api_failure(api_key)
|
||||
retries += 1
|
||||
|
||||
if retries >= MAX_RETRIES:
|
||||
logger.error("Max retries reached, giving up")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Max retries reached with all available API keys",
|
||||
)
|
||||
|
||||
logger.info(f"Retrying with new API key: {api_key}")
|
||||
continue
|
||||
|
||||
raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
@app.post("/hf/v1/embeddings")
|
||||
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
|
||||
await verify_authorization(authorization)
|
||||
api_key = await get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
try:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
|
||||
response = client.embeddings.create(input=request.input, model=request.model)
|
||||
logger.info("Embedding successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Error in embedding: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/")
|
||||
async def health_check():
|
||||
logger.info("Health check endpoint called")
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user