diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index cf330b3..fa93e6d 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -31,6 +31,11 @@ async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager return await key_manager.get_next_working_key() +async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取Gemini聊天服务实例""" + return GeminiChatService(settings.BASE_URL, key_manager) + + @router.get("/models") @router_v1beta.get("/models") async def list_models( @@ -88,7 +93,7 @@ async def generate_content( request: GeminiRequest, _=Depends(security_service.verify_key_or_goog_api_key), api_key: str = Depends(get_next_working_key), - key_manager: KeyManager = Depends(get_key_manager) + chat_service: GeminiChatService = Depends(get_chat_service) ): """非流式生成内容""" logger.info("-" * 50 + "gemini_generate_content" + "-" * 50) @@ -100,7 +105,6 @@ async def generate_content( raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") try: - chat_service = GeminiChatService(settings.BASE_URL, key_manager) response = await chat_service.generate_content( model=model_name, request=request, @@ -120,7 +124,7 @@ async def stream_generate_content( request: GeminiRequest, _=Depends(security_service.verify_key_or_goog_api_key), api_key: str = Depends(get_next_working_key), - key_manager: KeyManager = Depends(get_key_manager) + chat_service: GeminiChatService = Depends(get_chat_service) ): """流式生成内容""" logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50) @@ -132,7 +136,6 @@ async def stream_generate_content( raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") try: - chat_service = GeminiChatService(settings.BASE_URL, key_manager) response_stream = chat_service.stream_generate_content( model=model_name, request=request, @@ -145,14 +148,12 @@ async def stream_generate_content( @router.post("/verify-key/{api_key}") -async def verify_key(api_key: str): +async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service)): """验证Gemini API密钥的有效性""" logger.info("-" * 50 + "verify_gemini_key" + "-" * 50) logger.info("Verifying API key validity") try: - key_manager = await get_key_manager() - chat_service = GeminiChatService(settings.BASE_URL, key_manager) # 使用generate_content接口测试key的有效性 gemini_request = GeminiRequest( diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py index a378211..16ce913 100644 --- a/app/router/openai_routes.py +++ b/app/router/openai_routes.py @@ -36,6 +36,11 @@ async def get_next_working_key_wrapper( return await key_manager.get_next_working_key() +async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取OpenAI聊天服务实例""" + return OpenAIChatService(settings.BASE_URL, key_manager) + + @router.get("/v1/models") @router.get("/hf/v1/models") async def list_models( @@ -62,12 +67,12 @@ async def chat_completion( request: ChatRequest, _=Depends(security_service.verify_authorization), api_key: str = Depends(get_next_working_key_wrapper), - key_manager: KeyManager = Depends(get_key_manager), + key_manager: KeyManager = Depends(get_key_manager), # 保留 key_manager 用于获取 paid_key + chat_service: OpenAIChatService = Depends(get_openai_chat_service), ): # 如果model是imagen3,使用paid_key if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": api_key = await key_manager.get_paid_key() - chat_service = OpenAIChatService(settings.BASE_URL, key_manager) logger.info("-" * 50 + "chat_completion" + "-" * 50) logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}")