From b3851441f111cbe8968001e747c2ba9ee6dc743f Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Wed, 12 Feb 2025 17:10:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=20RetryHandler?= =?UTF-8?q?=20=E8=A3=85=E9=A5=B0=E5=99=A8=E4=BB=A5=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=8A=A8=E6=80=81=20KeyManager=20=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/gemini_routes.py | 4 ++-- app/api/openai_routes.py | 2 +- app/services/chat/retry_handler.py | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index b593c48..80f1903 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -47,7 +47,7 @@ async def list_models(_=Depends(security_service.verify_key), @router.post("/models/{model_name}:generateContent") @router_v1beta.post("/models/{model_name}:generateContent") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def generate_content( model_name: str, request: GeminiRequest, @@ -77,7 +77,7 @@ async def generate_content( @router.post("/models/{model_name}:streamGenerateContent") @router_v1beta.post("/models/{model_name}:streamGenerateContent") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def stream_generate_content( model_name: str, request: GeminiRequest, diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index 0d94e42..68f021a 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -46,7 +46,7 @@ async def list_models( @router.post("/v1/chat/completions") @router.post("/hf/v1/chat/completions") -@RetryHandler(max_retries=3, key_manager=Depends(get_key_manager), key_arg="api_key") +@RetryHandler(max_retries=3, key_arg="api_key") async def chat_completion( request: ChatRequest, _=Depends(security_service.verify_authorization), diff --git a/app/services/chat/retry_handler.py b/app/services/chat/retry_handler.py index 28985ff..b83f517 100644 --- a/app/services/chat/retry_handler.py +++ b/app/services/chat/retry_handler.py @@ -12,9 +12,8 @@ logger = get_retry_logger() class RetryHandler: """重试处理装饰器""" - def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"): + def __init__(self, max_retries: int = 3, key_arg: str = "api_key"): self.max_retries = max_retries - self.key_manager = key_manager self.key_arg = key_arg def __call__(self, func: Callable[..., T]) -> Callable[..., T]: @@ -29,9 +28,11 @@ class RetryHandler: last_exception = e logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}") - if self.key_manager: + # 从函数参数中获取 key_manager + key_manager = kwargs.get('key_manager') + if key_manager: old_key = kwargs.get(self.key_arg) - new_key = await self.key_manager.handle_api_failure(old_key) + new_key = await key_manager.handle_api_failure(old_key) kwargs[self.key_arg] = new_key logger.info(f"Switched to new API key: {new_key}")