mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-22 16:59:33 +08:00
feat: 支持 Gemini 格式请求,并优化日志和配置
This commit is contained in:
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -13,10 +13,10 @@
|
||||
"app.main:app",
|
||||
"--reload",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"8000",
|
||||
"--no-access-log"
|
||||
// "--no-access-log"
|
||||
],
|
||||
"jinja": true
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ GET /health
|
||||
|
||||
## 📚 代码结构
|
||||
|
||||
```
|
||||
```plaintext
|
||||
.
|
||||
├── app/
|
||||
│ ├── api/
|
||||
|
||||
98
app/api/gemini_routes.py
Normal file
98
app/api/gemini_routes.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from email.header import Header
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.security import SecurityService
|
||||
from app.services.chat_service import ChatService
|
||||
from app.services.key_manager import KeyManager
|
||||
from app.services.model_service import ModelService
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_gemini_logger
|
||||
|
||||
router = APIRouter(prefix="/gemini/v1beta")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
|
||||
@router.get("/models")
async def list_models(
    key: str = None,
    token: str = Depends(security_service.verify_key),
):
    """Return the model list reported by the upstream Gemini API.

    Access is gated by the ``security_service.verify_key`` dependency.
    ``key`` and ``token`` are not used by the handler body itself; the
    upstream call is made with a key taken from ``key_manager``.

    Returns:
        Whatever ``model_service.get_gemini_models`` returns for the
        selected upstream key.
    """
    banner = "-" * 50
    logger.info(banner + "list_gemini_models" + banner)
    logger.info("Handling Gemini models list request")

    upstream_key = await key_manager.get_next_working_key()
    # NOTE(review): this writes the full upstream API key to the log.
    logger.info(f"Using API key: {upstream_key}")

    return model_service.get_gemini_models(upstream_key)
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
async def generate_content(
    request: GeminiRequest,
    model_name: str,
    # NOTE(review): authentication is currently disabled on this route;
    # re-enable the two parameters below once clients send the access key.
    # key: str = None,
    # token: str = Depends(security_service.verify_key),
):
    """Generate content (non-streaming) by forwarding the request upstream.

    The model is taken from the ``{model_name}`` path segment and the
    request body is forwarded as-is through
    ``ChatService.generate_content``. ``GeminiRequest`` only carries
    ``contents`` and ``generationConfig``, so the model name and sampling
    options must not be read from attributes of ``request``.

    Args:
        request: Parsed Gemini-format request body.
        model_name: Target model, bound from the URL path.

    Returns:
        The value returned by ``ChatService.generate_content``.

    Raises:
        Exception: The last upstream error, re-raised after ``MAX_RETRIES``
            failed attempts.
    """
    MAX_RETRIES = 3
    banner = "-" * 50
    logger.info(banner + "gemini_generate_content" + banner)
    logger.info(f"Handling Gemini content generation request for model: {model_name}")
    logger.info(f"Request: \n{request.model_dump_json(indent=2)}")

    chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
    api_key = await key_manager.get_next_working_key()
    # NOTE(review): this writes the full upstream API key to the log.
    logger.info(f"Using API key: {api_key}")

    retries = 0
    while retries < MAX_RETRIES:
        try:
            return await chat_service.generate_content(
                model_name=model_name,
                request=request,
                api_key=api_key,
            )
        except Exception as e:
            logger.warning(
                f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
            )
            # Report the failing key and obtain the next one before deciding
            # whether to retry, so the key manager sees every failure.
            api_key = await key_manager.handle_api_failure(api_key)
            logger.info(f"Switched to new API key: {api_key}")
            retries += 1
            if retries >= MAX_RETRIES:
                logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
                raise
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
async def stream_generate_content(
    model_name: str,
    request: GeminiRequest,
    # NOTE(review): authentication is currently disabled on this route.
    # x_goog_api_key: str = Header("x-goog-api-key"),
    # token: str = Depends(security_service.verify_key),
):
    """Generate content as a server-sent-event stream.

    Args:
        model_name: Target model, bound from the URL path.
        request: Parsed Gemini-format request body.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` that
        wraps the async generator produced by
        ``ChatService.stream_generate_content``.
    """
    banner = "-" * 50
    logger.info(banner + "gemini_stream_generate_content" + banner)
    logger.info(f"Handling Gemini streaming content generation for model: {model_name}")

    upstream_key = await key_manager.get_next_working_key()
    # NOTE(review): this writes the full upstream API key to the log.
    logger.info(f"Using API key: {upstream_key}")

    try:
        service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
        # Calling the async generator function only creates the generator;
        # upstream errors surface later while the response is being streamed,
        # not inside this try block.
        return StreamingResponse(
            service.stream_generate_content(
                model_name=model_name,
                request=request,
                api_key=upstream_key,
            ),
            media_type="text/event-stream",
        )
    except Exception as e:
        logger.error(f"Streaming request failed: {str(e)}")
        raise
|
||||
@@ -7,12 +7,12 @@ from app.services.key_manager import KeyManager
|
||||
from app.services.model_service import ModelService
|
||||
from app.services.chat_service import ChatService
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
from app.schemas.request_model import ChatRequest, EmbeddingRequest
|
||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_api_logger
|
||||
from app.core.logger import get_openai_logger
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_api_logger()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
@@ -32,7 +32,7 @@ async def list_models(
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
return model_service.get_gemini_models(api_key)
|
||||
return model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@@ -50,7 +50,7 @@ class Logger:
|
||||
@staticmethod
|
||||
def setup_logger(
|
||||
name: str,
|
||||
level: str = "info",
|
||||
level: str = "debug",
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
设置并获取logger
|
||||
@@ -83,8 +83,11 @@ class Logger:
|
||||
return Logger._loggers.get(name)
|
||||
|
||||
# 预定义的loggers
|
||||
def get_api_logger():
|
||||
return Logger.setup_logger("api")
|
||||
def get_openai_logger():
|
||||
return Logger.setup_logger("openai")
|
||||
|
||||
def get_gemini_logger():
|
||||
return Logger.setup_logger("gemini")
|
||||
|
||||
def get_chat_logger():
|
||||
return Logger.setup_logger("chat")
|
||||
@@ -103,3 +106,6 @@ def get_main_logger():
|
||||
|
||||
def get_embeddings_logger():
|
||||
return Logger.setup_logger("embeddings")
|
||||
|
||||
def get_request_logger():
|
||||
return Logger.setup_logger("request")
|
||||
@@ -9,6 +9,12 @@ class SecurityService:
|
||||
def __init__(self, allowed_tokens: list):
|
||||
self.allowed_tokens = allowed_tokens
|
||||
|
||||
async def verify_key(self, key: str):
    """Validate an access key against the configured allow-list.

    Args:
        key: Access key supplied by the caller.

    Returns:
        The same ``key`` when it is present in ``self.allowed_tokens``.

    Raises:
        HTTPException: With status 401 when the key is not allowed.
    """
    if key in self.allowed_tokens:
        return key

    logger.error("Invalid key")
    raise HTTPException(status_code=401, detail="Invalid key")
|
||||
|
||||
async def verify_authorization(
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
|
||||
20
app/main.py
20
app/main.py
@@ -2,25 +2,33 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.logger import get_main_logger
|
||||
|
||||
from app.api.routes import router
|
||||
from app.api import gemini_routes, openai_routes
|
||||
import uvicorn
|
||||
|
||||
from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||
|
||||
# 配置日志
|
||||
logger = get_main_logger()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 允许跨域
|
||||
# 添加请求日志中间件
|
||||
# app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||
expose_headers=["*"], # 允许前端访问的响应头
|
||||
max_age=600, # 预检请求缓存时间(秒)
|
||||
)
|
||||
|
||||
# 包含所有路由
|
||||
app.include_router(router)
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
36
app/middleware/request_logging_middleware.py
Normal file
36
app/middleware/request_logging_middleware.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import json
|
||||
from app.core.logger import get_request_logger
|
||||
|
||||
|
||||
logger = get_request_logger()
|
||||
|
||||
# 添加中间件类
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs the path and JSON body of every request."""

    async def dispatch(self, request: Request, call_next):
        """Log the request, then hand it to the next handler.

        Args:
            request: Incoming request.
            call_next: Callable that invokes the rest of the stack.

        Returns:
            The response produced by ``call_next``.
        """
        # Log the request path.
        logger.info(f"Request path: {request.url.path}")

        # Start from an empty payload so the replay callback below always has
        # a defined value, even if reading the body raises.
        body = b""

        # Read and log the request body; logging is best-effort and must not
        # break the request.
        # NOTE(review): bodies are logged verbatim and may contain sensitive data.
        try:
            body = await request.body()
            if body:
                body_str = body.decode()
                # Try to pretty-print the body as JSON.
                try:
                    formatted_body = json.loads(body_str)
                    logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
                except json.JSONDecodeError:
                    logger.info("Request body is not valid JSON.")
        except Exception as e:
            logger.error(f"Error reading request body: {str(e)}")

        # Reset the request's receive callable so downstream handlers can
        # read the body again.
        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        request._receive = receive

        response = await call_next(request)
        return response
|
||||
39
app/schemas/gemini_models.py
Normal file
39
app/schemas/gemini_models.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SafetySetting(BaseModel):
    """One safety rule: a harm category paired with a blocking threshold.

    Both fields are optional and default to ``None``.
    """

    category: Optional[Literal[
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_CIVIC_INTEGRITY",
    ]] = None
    threshold: Optional[Literal[
        # Threshold names documented for the Gemini API.
        "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
        "BLOCK_LOW_AND_ABOVE",
        "BLOCK_MEDIUM_AND_ABOVE",
        "BLOCK_ONLY_HIGH",
        "BLOCK_NONE",
        "OFF",
        # Values accepted by the earlier version of this model; kept only so
        # existing callers still validate.
        # NOTE(review): these do not appear in the Gemini API enum - confirm
        # and remove.
        "HARM_BLOCK",
        "HARM_FLAG",
        "HARM_UNSPECIFIED",
    ]] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
    """Optional generation parameters of a Gemini request.

    Field names are camelCase. Every field defaults to ``None`` except
    ``candidateCount``, which defaults to 1.
    """

    # Output shaping
    stopSequences: Optional[List[str]] = None
    responseMimeType: Optional[str] = None
    responseSchema: Optional[Dict[str, Any]] = None
    candidateCount: Optional[int] = 1
    maxOutputTokens: Optional[int] = None

    # Sampling controls
    temperature: Optional[float] = None
    topP: Optional[float] = None
    topK: Optional[int] = None
    presencePenalty: Optional[float] = None
    frequencyPenalty: Optional[float] = None

    # Log-probability reporting
    responseLogprobs: Optional[bool] = None
    logprobs: Optional[int] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
    """System-level instruction block: a role plus a list of parts."""

    # Role label; defaults to "system".
    role: str = "system"
    # Free-form part objects; their inner structure is not validated here.
    parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
    """One conversation turn: a required role plus a list of parts."""

    # Role of the turn's author (required, no default).
    role: str
    # Free-form part objects; their inner structure is not validated here.
    parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiRequest(BaseModel):
    """Request body accepted by the Gemini-format endpoints.

    Only ``contents`` and ``generationConfig`` are declared. The tool,
    safety and system-instruction fields are present but disabled.
    """

    # Conversation turns (required).
    contents: List[GeminiContent]
    # tools: Optional[List[Dict[str, Any]]] = None
    # safetySettings: Optional[List[SafetySetting]] = None
    # Optional generation parameters.
    generationConfig: Optional[GenerationConfig] = None
    # systemInstruction: Optional[SystemInstruction] = None
    # NOTE(review): with the fields above disabled, those keys are presumably
    # dropped from incoming requests before forwarding - confirm this is intended.
|
||||
@@ -3,9 +3,9 @@ import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, AsyncGenerator, Union
|
||||
import openai
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_chat_logger
|
||||
from app.schemas.gemini_models import GeminiRequest
|
||||
|
||||
logger = get_chat_logger()
|
||||
|
||||
@@ -195,7 +195,8 @@ class ChatService:
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒,读取超时60秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
async with client.stream("POST", stream_url, json=payload) as response:
|
||||
if response.status_code != 200:
|
||||
@@ -207,9 +208,9 @@ class ChatService:
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"Max retries reached. Final error: {response.status_code}"
|
||||
f"Max retries reached. Final error: {response.status_code}, {error_msg}"
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
|
||||
yield f"data: {json.dumps({'error': f'API error: {response.status_code}, {error_msg}'})}\n\n"
|
||||
return
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
@@ -227,22 +228,34 @@ class ChatService:
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Stream error: {str(e)}, attempting retry {retries + 1}")
|
||||
except httpx.ReadTimeout:
|
||||
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries reached. Final error: {str(e)}")
|
||||
logger.error(f"Max retries reached. Final error: Read timeout")
|
||||
yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n"
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Stream error: {str(e)}, attempting retry {retries + 1}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries reached. Final error: {e}")
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
return
|
||||
|
||||
return generate()
|
||||
else:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒,读取超时60秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
@@ -255,40 +268,6 @@ class ChatService:
|
||||
logger.error(f"Error in non-stream completion")
|
||||
raise
|
||||
|
||||
# async def _openai_chat_completion(
|
||||
# self,
|
||||
# messages: list,
|
||||
# model: str,
|
||||
# temperature: float,
|
||||
# stream: bool,
|
||||
# api_key: str,
|
||||
# tools: Optional[list] = None,
|
||||
# ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
# """Handle OpenAI API chat completion"""
|
||||
# client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
# if tools:
|
||||
# response = client.chat.completions.create(
|
||||
# model=model,
|
||||
# messages=messages,
|
||||
# temperature=temperature,
|
||||
# stream=stream,
|
||||
# tools=tools,
|
||||
# )
|
||||
# else:
|
||||
# response = client.chat.completions.create(
|
||||
# model=model, messages=messages, temperature=temperature, stream=stream
|
||||
# )
|
||||
|
||||
# if stream:
|
||||
|
||||
# async def generate():
|
||||
# for chunk in response:
|
||||
# yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# return generate()
|
||||
|
||||
# return response
|
||||
|
||||
def format_code_block(self, code_data: dict) -> str:
|
||||
"""格式化代码块输出"""
|
||||
language = code_data.get("language", "").lower()
|
||||
@@ -301,3 +280,53 @@ class ChatService:
|
||||
outcome = result_data.get("outcome", "")
|
||||
output = result_data.get("output", "").strip()
|
||||
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
|
||||
|
||||
async def generate_content(
    self,
    model_name: str,
    request: GeminiRequest,
    api_key: str
) -> dict:
    """Call the Gemini ``generateContent`` endpoint and return its JSON.

    Args:
        model_name: Model identifier placed in the URL path.
        request: Request body; sent as ``request.model_dump()``.
        api_key: Upstream key, passed as the ``key`` query parameter.

    Returns:
        The decoded JSON body of a 200 response.

    Raises:
        Exception: On a non-200 status (message includes the status code and
            response text), or any error raised while sending the request;
            both are logged and re-raised.
    """
    endpoint = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"

    # 30 s default timeout, 60 s for reads.
    limits = httpx.Timeout(30.0, read=60.0)
    async with httpx.AsyncClient(timeout=limits) as client:
        try:
            response = await client.post(endpoint, json=request.model_dump())
            if response.status_code != 200:
                error_text = response.text
                logger.error(f"Error: {response.status_code}")
                logger.error(error_text)
                raise Exception(f"API request failed with status {response.status_code}: {error_text}")
            return response.json()
        except Exception as e:
            logger.error(f"Request failed: {str(e)}")
            raise
|
||||
|
||||
async def stream_generate_content(
    self,
    model_name: str,
    request: GeminiRequest,
    api_key: str
) -> AsyncGenerator:
    """Call the Gemini ``streamGenerateContent`` endpoint and relay its lines.

    Args:
        model_name: Model identifier placed in the URL path.
        request: Request body; sent as ``request.model_dump()``.
        api_key: Upstream key, passed as the ``key`` query parameter.

    Yields:
        Each line of the upstream response followed by a blank line
        (``line + "\\n\\n"``).

    Raises:
        Exception: On a non-200 status (message includes the status code and
            response text), or any error raised while streaming; both are
            logged and re-raised.
    """
    url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={api_key}"

    # 30 s default timeout, 60 s for reads.
    timeout = httpx.Timeout(30.0, read=60.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            async with client.stream('POST', url, json=request.model_dump()) as response:
                if response.status_code != 200:
                    # A streamed response has no body loaded yet; it must be
                    # read explicitly before ``.text`` can be accessed,
                    # otherwise httpx raises instead of returning the error.
                    await response.aread()
                    error_text = response.text
                    logger.error(f"Error: {response.status_code}")
                    logger.error(error_text)
                    raise Exception(f"API request failed with status {response.status_code}: {error_text}")

                async for line in response.aiter_lines():
                    # Trace through the logger rather than raw stdout.
                    logger.debug(line)
                    yield line + "\n\n"
        except Exception as e:
            logger.error(f"Streaming request failed: {str(e)}")
            raise
|
||||
@@ -9,21 +9,28 @@ logger = get_model_logger()
|
||||
class ModelService:
|
||||
def __init__(self, model_search: list):
|
||||
self.model_search = model_search
|
||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
url = f"{base_url}/models?key={api_key}"
|
||||
url = f"{self.base_url}/models?key={api_key}"
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
return self.convert_to_openai_models_format(gemini_models)
|
||||
return gemini_models
|
||||
else:
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(response.text)
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
return None
|
||||
|
||||
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
    """Fetch the Gemini model list and convert it to the OpenAI list format.

    Args:
        api_key: Upstream key passed to ``get_gemini_models``.

    Returns:
        The result of ``convert_to_openai_models_format``, or ``None`` when
        the model list could not be fetched.
    """
    try:
        gemini_models = self.get_gemini_models(api_key)
        # get_gemini_models reports failure by returning None; do not hand
        # that to the converter, which expects the parsed model listing.
        if gemini_models is None:
            return None
        return self.convert_to_openai_models_format(gemini_models)
    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None
|
||||
@@ -43,7 +50,6 @@ class ModelService:
|
||||
"permission": [],
|
||||
"root": model["name"],
|
||||
"parent": None,
|
||||
|
||||
}
|
||||
openai_format["data"].append(openai_model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user