mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-30 04:01:39 +08:00
feat: 为API添加统一的鉴权token
This commit is contained in:
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/gemini/v1beta")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
|
||||
|
||||
@@ -15,7 +15,7 @@ router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS)
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
key_manager = KeyManager(settings.API_KEYS)
|
||||
model_service = ModelService(settings.MODEL_SEARCH)
|
||||
chat_service = ChatService(settings.BASE_URL, key_manager)
|
||||
@@ -100,7 +100,7 @@ async def embedding(
|
||||
@router.get("/hf/v1/keys/list")
|
||||
async def get_keys_list(
|
||||
authorization: str = Header(None),
|
||||
token: str = Depends(security_service.verify_authorization),
|
||||
token: str = Depends(security_service.verify_auth_token),
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||
|
||||
@@ -10,7 +10,13 @@ class Settings(BaseSettings):
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
|
||||
AUTH_TOKEN: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not self.AUTH_TOKEN:
|
||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ logger = get_security_logger()
|
||||
|
||||
|
||||
class SecurityService:
|
||||
def __init__(self, allowed_tokens: list):
|
||||
def __init__(self, allowed_tokens: list, auth_token: str):
|
||||
self.allowed_tokens = allowed_tokens
|
||||
self.auth_token = auth_token
|
||||
|
||||
async def verify_key(self, key: str):
|
||||
if key not in self.allowed_tokens:
|
||||
if key not in self.allowed_tokens and key != self.auth_token:
|
||||
logger.error("Invalid key")
|
||||
raise HTTPException(status_code=401, detail="Invalid key")
|
||||
return key
|
||||
@@ -29,7 +30,7 @@ class SecurityService:
|
||||
)
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token not in self.allowed_tokens:
|
||||
if token not in self.allowed_tokens and token != self.auth_token:
|
||||
logger.error("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
@@ -41,8 +42,19 @@ class SecurityService:
|
||||
logger.error("Missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in self.allowed_tokens:
|
||||
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
|
||||
logger.error("Invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
||||
|
||||
return x_goog_api_key
|
||||
|
||||
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
|
||||
if not authorization:
|
||||
logger.error("Missing auth_token header")
|
||||
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token != self.auth_token:
|
||||
logger.error("Invalid auth_token")
|
||||
raise HTTPException(status_code=401, detail="Invalid auth_token")
|
||||
|
||||
return token
|
||||
Reference in New Issue
Block a user