diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index b593c48..80f1903 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -47,7 +47,7 @@ async def list_models(_=Depends(security_service.verify_key), @router.post("/models/{model_name}:generateContent") @router_v1beta.post("/models/{model_name}:generateContent") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def generate_content( model_name: str, request: GeminiRequest, @@ -77,7 +77,7 @@ async def generate_content( @router.post("/models/{model_name}:streamGenerateContent") @router_v1beta.post("/models/{model_name}:streamGenerateContent") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def stream_generate_content( model_name: str, request: GeminiRequest, diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index 0d94e42..68f021a 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -46,7 +46,7 @@ async def list_models( @router.post("/v1/chat/completions") @router.post("/hf/v1/chat/completions") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def chat_completion( request: ChatRequest, _=Depends(security_service.verify_authorization), diff --git a/app/services/chat/retry_handler.py b/app/services/chat/retry_handler.py index 28985ff..b83f517 100644 --- a/app/services/chat/retry_handler.py +++ b/app/services/chat/retry_handler.py @@ -12,9 +12,8 @@ logger = get_retry_logger() class RetryHandler: """重试处理装饰器""" - def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"): + def __init__(self, max_retries: int = 3, key_arg: str = "api_key"): self.max_retries = max_retries - self.key_manager = key_manager self.key_arg = key_arg def __call__(self, func: Callable[..., T]) -> Callable[..., T]: @@ -29,9 +28,11 @@ class RetryHandler: last_exception = e logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}") - if self.key_manager: + # 从函数参数中获取 key_manager + key_manager = kwargs.get('key_manager') + if key_manager: old_key = kwargs.get(self.key_arg) - new_key = await self.key_manager.handle_api_failure(old_key) + new_key = await key_manager.handle_api_failure(old_key) kwargs[self.key_arg] = new_key logger.info(f"Switched to new API key: {new_key}")