diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index 14c828b..056f471 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/gemini/v1beta") logger = get_gemini_logger() # 初始化服务 -security_service = SecurityService(settings.ALLOWED_TOKENS) +security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) key_manager = KeyManager(settings.API_KEYS) model_service = ModelService(settings.MODEL_SEARCH) chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager) diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index 0ec55d6..bd0a27e 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -15,7 +15,7 @@ router = APIRouter() logger = get_openai_logger() # 初始化服务 -security_service = SecurityService(settings.ALLOWED_TOKENS) +security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) key_manager = KeyManager(settings.API_KEYS) model_service = ModelService(settings.MODEL_SEARCH) chat_service = ChatService(settings.BASE_URL, key_manager) @@ -100,7 +100,7 @@ async def embedding( @router.get("/hf/v1/keys/list") async def get_keys_list( authorization: str = Header(None), - token: str = Depends(security_service.verify_authorization), + token: str = Depends(security_service.verify_auth_token), ): """获取有效和无效的API key列表""" logger.info("-" * 50 + "get_keys_list" + "-" * 50) diff --git a/app/core/config.py b/app/core/config.py index cba379b..dcd2ccf 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -10,7 +10,13 @@ class Settings(BaseSettings): TOOLS_CODE_EXECUTION_ENABLED: bool = False SHOW_SEARCH_LINK: bool = True SHOW_THINKING_PROCESS: bool = True - + AUTH_TOKEN: str + + def __init__(self): + super().__init__() + if not self.AUTH_TOKEN: + self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else "" + class Config: env_file = ".env" diff --git a/app/core/security.py b/app/core/security.py index a23f20d..f2508fa 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -6,11 +6,12 @@ logger = get_security_logger() class SecurityService: - def __init__(self, allowed_tokens: list): + def __init__(self, allowed_tokens: list, auth_token: str): self.allowed_tokens = allowed_tokens + self.auth_token = auth_token async def verify_key(self, key: str): - if key not in self.allowed_tokens: + if key not in self.allowed_tokens and key != self.auth_token: logger.error("Invalid key") raise HTTPException(status_code=401, detail="Invalid key") return key @@ -29,7 +30,7 @@ class SecurityService: ) token = authorization.replace("Bearer ", "") - if token not in self.allowed_tokens: + if token not in self.allowed_tokens and token != self.auth_token: logger.error("Invalid token") raise HTTPException(status_code=401, detail="Invalid token") @@ -41,8 +42,19 @@ class SecurityService: logger.error("Missing x-goog-api-key header") raise HTTPException(status_code=401, detail="Missing x-goog-api-key header") - if x_goog_api_key not in self.allowed_tokens: + if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token: logger.error("Invalid x-goog-api-key") raise HTTPException(status_code=401, detail="Invalid x-goog-api-key") return x_goog_api_key + + async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str: + if not authorization: + logger.error("Missing auth_token header") + raise HTTPException(status_code=401, detail="Missing auth_token header") + token = authorization.replace("Bearer ", "") + if token != self.auth_token: + logger.error("Invalid auth_token") + raise HTTPException(status_code=401, detail="Invalid auth_token") + + return token \ No newline at end of file