mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-07 06:32:43 +08:00
feat: 将 Gemini API 调用迁移至 ChatService 并支持 API Key 验证
This commit is contained in:
@@ -17,6 +17,7 @@ logger = get_gemini_logger()
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(
|
||||
@@ -32,13 +33,13 @@ async def list_models(
|
||||
|
||||
@router.post("/models/{model_name}:generateContent")
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
# key: str = None,
|
||||
# token: str = Depends(security_service.verify_key),
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
):
|
||||
"""非流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||
logger.info(f"Handling Gemini content generation request for model: {request.model}")
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
@@ -48,13 +49,9 @@ async def generate_content(
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
response = await model_service.generate_content(
|
||||
contents=request.contents,
|
||||
model=request.model,
|
||||
temperature=request.temperature,
|
||||
candidate_count=request.candidate_count,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
response = await chat_service.generate_content(
|
||||
model_name=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
@@ -68,14 +65,12 @@ async def generate_content(
|
||||
retries += 1
|
||||
if retries >= MAX_RETRIES:
|
||||
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
|
||||
raise
|
||||
|
||||
@router.post("/models/{model_name}:streamGenerateContent")
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
# x_goog_api_key: str = Header("x-goog-api-key"),
|
||||
# token: str = Depends(security_service.verify_key),
|
||||
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
|
||||
):
|
||||
"""流式生成内容"""
|
||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||
@@ -94,5 +89,4 @@ async def stream_generate_content(
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
raise
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
@@ -34,3 +34,15 @@ class SecurityService:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
return token
|
||||
|
||||
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
|
||||
"""验证Google API Key"""
|
||||
if not x_goog_api_key:
|
||||
logger.error("Missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in self.allowed_tokens:
|
||||
logger.error("Invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
||||
|
||||
return x_goog_api_key
|
||||
|
||||
@@ -3,7 +3,7 @@ from pydantic import BaseModel
|
||||
|
||||
class SafetySetting(BaseModel):
|
||||
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT"]] = None
|
||||
threshold: Optional[Literal["HARM_BLOCK", "HARM_FLAG", "HARM_UNSPECIFIED"]] = None
|
||||
threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE","BLOCK_ONLY_HIGH","BLOCK_NONE","OFF"]] = None
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
@@ -33,7 +33,7 @@ class GeminiContent(BaseModel):
|
||||
|
||||
class GeminiRequest(BaseModel):
|
||||
contents: List[GeminiContent]
|
||||
# tools: Optional[List[Dict[str, Any]]] = None
|
||||
# safetySettings: Optional[List[SafetySetting]] = None
|
||||
tools: Optional[List[Dict[str, Any]]] = []
|
||||
safetySettings: Optional[List[SafetySetting]] = None
|
||||
generationConfig: Optional[GenerationConfig] = None
|
||||
# systemInstruction: Optional[SystemInstruction] = None
|
||||
systemInstruction: Optional[SystemInstruction] = None
|
||||
@@ -312,21 +312,45 @@ class ChatService:
|
||||
api_key: str
|
||||
) -> AsyncGenerator:
|
||||
"""调用Gemini API流式生成内容"""
|
||||
url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
|
||||
timeout = httpx.Timeout(60.0, read=60.0) # 连接超时60秒,读取超时60秒
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
retries = 0
|
||||
MAX_RETRIES = 3
|
||||
current_api_key = api_key
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
try:
|
||||
async with client.stream('POST', url, json=request.model_dump()) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"Error: {response.status_code}")
|
||||
logger.error(error_text)
|
||||
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
print(line)
|
||||
yield line + "\n\n"
|
||||
url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}"
|
||||
timeout = httpx.Timeout(60.0, read=60.0)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream('POST', url, json=request.model_dump()) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Error: {response.status_code}: {error_text}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
raise Exception(f"API request failed with status {response.status_code}: {error_text}")
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
yield line + "\n\n"
|
||||
return
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming request failed: {str(e)}")
|
||||
if retries < MAX_RETRIES - 1:
|
||||
current_api_key = await self.key_manager.handle_api_failure(current_api_key)
|
||||
logger.info(f"Switched to new API key: {current_api_key}")
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
@@ -4,4 +4,5 @@ openai
|
||||
pydantic
|
||||
pydantic_settings
|
||||
requests
|
||||
starlette
|
||||
uvicorn
|
||||
Reference in New Issue
Block a user