mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 14:21:27 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
84052a2179 | ||
|
|
2e7ecd88b5 | ||
|
|
0b1f3dfc04 | ||
|
|
c691c7c1cf |
@@ -72,3 +72,22 @@ class SecurityService:
|
|||||||
raise HTTPException(status_code=401, detail="Invalid auth_token")
|
raise HTTPException(status_code=401, detail="Invalid auth_token")
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
async def verify_key_or_goog_api_key(
|
||||||
|
self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None)
|
||||||
|
) -> str:
|
||||||
|
"""验证URL中的key或请求头中的x-goog-api-key"""
|
||||||
|
# 如果URL中的key有效,直接返回
|
||||||
|
if key in self.allowed_tokens or key == self.auth_token:
|
||||||
|
return key
|
||||||
|
|
||||||
|
# 否则检查请求头中的x-goog-api-key
|
||||||
|
if not x_goog_api_key:
|
||||||
|
logger.error("Invalid key and missing x-goog-api-key header")
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
|
||||||
|
|
||||||
|
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
|
||||||
|
logger.error("Invalid key and invalid x-goog-api-key")
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
|
||||||
|
|
||||||
|
return x_goog_api_key
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Dict, Any, Literal
|
from typing import List, Optional, Dict, Any, Literal, Union
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ class GeminiContent(BaseModel):
|
|||||||
|
|
||||||
class GeminiRequest(BaseModel):
|
class GeminiRequest(BaseModel):
|
||||||
contents: List[GeminiContent] = []
|
contents: List[GeminiContent] = []
|
||||||
tools: Optional[List[Dict[str, Any]]] = []
|
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||||
safetySettings: Optional[List[SafetySetting]] = None
|
safetySettings: Optional[List[SafetySetting]] = None
|
||||||
generationConfig: Optional[GenerationConfig] = None
|
generationConfig: Optional[GenerationConfig] = None
|
||||||
systemInstruction: Optional[SystemInstruction] = None
|
systemInstruction: Optional[SystemInstruction] = None
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager
|
|||||||
@router.get("/models")
|
@router.get("/models")
|
||||||
@router_v1beta.get("/models")
|
@router_v1beta.get("/models")
|
||||||
async def list_models(
|
async def list_models(
|
||||||
_=Depends(security_service.verify_key),
|
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
):
|
):
|
||||||
"""获取可用的Gemini模型列表"""
|
"""获取可用的Gemini模型列表"""
|
||||||
@@ -86,7 +86,7 @@ async def list_models(
|
|||||||
async def generate_content(
|
async def generate_content(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
request: GeminiRequest,
|
request: GeminiRequest,
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||||
api_key: str = Depends(get_next_working_key),
|
api_key: str = Depends(get_next_working_key),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
):
|
):
|
||||||
@@ -118,7 +118,7 @@ async def generate_content(
|
|||||||
async def stream_generate_content(
|
async def stream_generate_content(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
request: GeminiRequest,
|
request: GeminiRequest,
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||||
api_key: str = Depends(get_next_working_key),
|
api_key: str = Depends(get_next_working_key),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
tool = dict()
|
tool = dict()
|
||||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||||
|
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
||||||
|
payload["tools"] = [payload.get("tools")]
|
||||||
items = payload.get("tools", [])
|
items = payload.get("tools", [])
|
||||||
if items and isinstance(items, list):
|
if items and isinstance(items, list):
|
||||||
tool.update(_merge_tools(items))
|
tool.update(_merge_tools(items))
|
||||||
@@ -62,7 +64,7 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|||||||
tool.pop("googleSearch", None)
|
tool.pop("googleSearch", None)
|
||||||
tool.pop("codeExecution", None)
|
tool.pop("codeExecution", None)
|
||||||
|
|
||||||
return [tool]
|
return [tool] if tool else []
|
||||||
|
|
||||||
|
|
||||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||||
|
|||||||
Reference in New Issue
Block a user