feat: 支持 Gemini 格式请求,并优化日志和配置

This commit is contained in:
yinpeng
2024-12-18 19:54:43 +08:00
parent 1913a3c909
commit d9229cced9
14 changed files with 290 additions and 62 deletions

4
.vscode/launch.json vendored
View File

@@ -13,10 +13,10 @@
"app.main:app",
"--reload",
"--host",
"127.0.0.1",
"0.0.0.0",
"--port",
"8000",
"--no-access-log"
// "--no-access-log"
],
"jinja": true
}

View File

@@ -100,7 +100,7 @@ GET /health
## 📚 代码结构
```
```plaintext
.
├── app/
│ ├── api/

98
app/api/gemini_routes.py Normal file
View File

@@ -0,0 +1,98 @@
from email.header import Header
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from app.core.security import SecurityService
from app.services.chat_service import ChatService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.core.logger import get_gemini_logger
router = APIRouter(prefix="/gemini/v1beta")
logger = get_gemini_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
@router.get("/models")
async def list_models(
key: str = None,
token: str = Depends(security_service.verify_key),
):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return model_service.get_gemini_models(api_key)
@router.post("/models/{model_name}:generateContent")
async def generate_content(
request: GeminiRequest,
# key: str = None,
# token: str = Depends(security_service.verify_key),
):
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
MAX_RETRIES = 3
while retries < MAX_RETRIES:
try:
response = await model_service.generate_content(
contents=request.contents,
model=request.model,
temperature=request.temperature,
candidate_count=request.candidate_count,
top_p=request.top_p,
top_k=request.top_k,
api_key=api_key
)
return response
except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= MAX_RETRIES:
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
raise
@router.post("/models/{model_name}:streamGenerateContent")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
# x_goog_api_key: str = Header("x-goog-api-key"),
# token: str = Depends(security_service.verify_key),
):
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
response_stream = chat_service.stream_generate_content(
model_name=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise

View File

@@ -7,12 +7,12 @@ from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat_service import ChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.request_model import ChatRequest, EmbeddingRequest
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.core.config import settings
from app.core.logger import get_api_logger
from app.core.logger import get_openai_logger
router = APIRouter()
logger = get_api_logger()
logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS)
@@ -32,7 +32,7 @@ async def list_models(
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return model_service.get_gemini_models(api_key)
return model_service.get_gemini_openai_models(api_key)
@router.post("/v1/chat/completions")

View File

@@ -50,7 +50,7 @@ class Logger:
@staticmethod
def setup_logger(
name: str,
level: str = "info",
level: str = "debug",
) -> logging.Logger:
"""
设置并获取logger
@@ -83,8 +83,11 @@ class Logger:
return Logger._loggers.get(name)
# 预定义的loggers
def get_api_logger():
return Logger.setup_logger("api")
def get_openai_logger():
return Logger.setup_logger("openai")
def get_gemini_logger():
return Logger.setup_logger("gemini")
def get_chat_logger():
return Logger.setup_logger("chat")
@@ -103,3 +106,6 @@ def get_main_logger():
def get_embeddings_logger():
return Logger.setup_logger("embeddings")
def get_request_logger():
return Logger.setup_logger("request")

View File

@@ -9,6 +9,12 @@ class SecurityService:
def __init__(self, allowed_tokens: list):
self.allowed_tokens = allowed_tokens
async def verify_key(self, key: str):
if key not in self.allowed_tokens:
logger.error("Invalid key")
raise HTTPException(status_code=401, detail="Invalid key")
return key
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
) -> str:

View File

@@ -2,25 +2,33 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.core.logger import get_main_logger
from app.api.routes import router
from app.api import gemini_routes, openai_routes
import uvicorn
from app.middleware.request_logging_middleware import RequestLoggingMiddleware
# 配置日志
logger = get_main_logger()
app = FastAPI()
# 允许跨域
# 添加请求日志中间件
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)
# 包含所有路由
app.include_router(router)
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
@app.get("/health")

View File

@@ -0,0 +1,36 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import json
from app.core.logger import get_request_logger
logger = get_request_logger()
# 添加中间件类
class RequestLoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 记录请求路径
logger.info(f"Request path: {request.url.path}")
# 获取并记录请求体
try:
body = await request.body()
if body:
body_str = body.decode()
# 尝试格式化JSON
try:
formatted_body = json.loads(body_str)
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
except json.JSONDecodeError:
logger.info("Request body is not valid JSON.")
except Exception as e:
logger.error(f"Error reading request body: {str(e)}")
# 重置请求的接收器,以便后续处理器可以继续读取请求体
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
request._receive = receive
response = await call_next(request)
return response

View File

@@ -0,0 +1,39 @@
from typing import List, Optional, Dict, Any, Literal
from pydantic import BaseModel
class SafetySetting(BaseModel):
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT"]] = None
threshold: Optional[Literal["HARM_BLOCK", "HARM_FLAG", "HARM_UNSPECIFIED"]] = None
class GenerationConfig(BaseModel):
stopSequences: Optional[List[str]] = None
responseMimeType: Optional[str] = None
responseSchema: Optional[Dict[str, Any]] = None
candidateCount: Optional[int] = 1
maxOutputTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
topK: Optional[int] = None
presencePenalty: Optional[float] = None
frequencyPenalty: Optional[float] = None
responseLogprobs: Optional[bool] = None
logprobs: Optional[int] = None
class SystemInstruction(BaseModel):
role: str = "system"
parts: List[Dict[str, Any]]
class GeminiContent(BaseModel):
role: str
parts: List[Dict[str, Any]]
class GeminiRequest(BaseModel):
contents: List[GeminiContent]
# tools: Optional[List[Dict[str, Any]]] = None
# safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = None
# systemInstruction: Optional[SystemInstruction] = None

View File

@@ -3,9 +3,9 @@ import json
import time
import uuid
from typing import Dict, Any, Optional, AsyncGenerator, Union
import openai
from app.core.config import settings
from app.core.logger import get_chat_logger
from app.schemas.gemini_models import GeminiRequest
logger = get_chat_logger()
@@ -195,7 +195,8 @@ class ChatService:
while retries < MAX_RETRIES:
try:
async with httpx.AsyncClient() as client:
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒读取超时60秒
async with httpx.AsyncClient(timeout=timeout) as client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
async with client.stream("POST", stream_url, json=payload) as response:
if response.status_code != 200:
@@ -207,9 +208,9 @@ class ChatService:
continue
else:
logger.error(
f"Max retries reached. Final error: {response.status_code}"
f"Max retries reached. Final error: {response.status_code}, {error_msg}"
)
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
yield f"data: {json.dumps({'error': f'API error: {response.status_code}, {error_msg}'})}\n\n"
return
async for line in response.aiter_lines():
@@ -227,22 +228,34 @@ class ChatService:
yield "data: [DONE]\n\n"
return
except Exception as e:
logger.warning(f"Stream error: {str(e)}, attempting retry {retries + 1}")
except httpx.ReadTimeout:
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
if retries < MAX_RETRIES - 1:
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
else:
logger.error(f"Max retries reached. Final error: {str(e)}")
logger.error(f"Max retries reached. Final error: Read timeout")
yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n"
return
except Exception as e:
logger.exception(f"Stream error: {str(e)}, attempting retry {retries + 1}")
if retries < MAX_RETRIES - 1:
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
else:
logger.error(f"Max retries reached. Final error: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return
return generate()
else:
try:
async with httpx.AsyncClient() as client:
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒读取超时60秒
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
if response.status_code != 200:
@@ -255,40 +268,6 @@ class ChatService:
logger.error(f"Error in non-stream completion")
raise
# async def _openai_chat_completion(
# self,
# messages: list,
# model: str,
# temperature: float,
# stream: bool,
# api_key: str,
# tools: Optional[list] = None,
# ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
# """Handle OpenAI API chat completion"""
# client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
# if tools:
# response = client.chat.completions.create(
# model=model,
# messages=messages,
# temperature=temperature,
# stream=stream,
# tools=tools,
# )
# else:
# response = client.chat.completions.create(
# model=model, messages=messages, temperature=temperature, stream=stream
# )
# if stream:
# async def generate():
# for chunk in response:
# yield f"data: {chunk.model_dump_json()}\n\n"
# return generate()
# return response
def format_code_block(self, code_data: dict) -> str:
"""格式化代码块输出"""
language = code_data.get("language", "").lower()
@@ -301,3 +280,53 @@ class ChatService:
outcome = result_data.get("outcome", "")
output = result_data.get("output", "").strip()
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
async def generate_content(
self,
model_name: str,
request: GeminiRequest,
api_key: str
) -> dict:
"""调用Gemini API生成内容"""
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒读取超时60秒
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(url, json=request.model_dump())
if response.status_code == 200:
return response.json()
else:
error_text = response.text
logger.error(f"Error: {response.status_code}")
logger.error(error_text)
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
except Exception as e:
logger.error(f"Request failed: {str(e)}")
raise
async def stream_generate_content(
self,
model_name: str,
request: GeminiRequest,
api_key: str
) -> AsyncGenerator:
"""调用Gemini API流式生成内容"""
url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={api_key}"
timeout = httpx.Timeout(30.0, read=60.0) # 连接超时30秒读取超时60秒
async with httpx.AsyncClient(timeout=timeout) as client:
try:
async with client.stream('POST', url, json=request.model_dump()) as response:
if response.status_code != 200:
error_text = response.text
logger.error(f"Error: {response.status_code}")
logger.error(error_text)
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
async for line in response.aiter_lines():
print(line)
yield line + "\n\n"
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise

View File

@@ -9,21 +9,28 @@ logger = get_model_logger()
class ModelService:
def __init__(self, model_search: list):
self.model_search = model_search
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
base_url = "https://generativelanguage.googleapis.com/v1beta"
url = f"{base_url}/models?key={api_key}"
url = f"{self.base_url}/models?key={api_key}"
try:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
return self.convert_to_openai_models_format(gemini_models)
return gemini_models
else:
logger.error(f"Error: {response.status_code}")
logger.error(response.text)
return None
except requests.RequestException as e:
logger.error(f"Request failed: {e}")
return None
def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
try:
gemini_models = self.get_gemini_models(api_key)
return self.convert_to_openai_models_format(gemini_models)
except requests.RequestException as e:
logger.error(f"Request failed: {e}")
return None
@@ -43,7 +50,6 @@ class ModelService:
"permission": [],
"root": model["name"],
"parent": None,
}
openai_format["data"].append(openai_model)