From 8dfe617468a9da397b1ef8b9533e4fa9e59b930e Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Wed, 18 Dec 2024 21:35:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=B0=86=20Gemini=20API=20=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E8=BF=81=E7=A7=BB=E8=87=B3=20ChatService=20=E5=B9=B6?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20API=20Key=20=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/gemini_routes.py | 24 +++++++---------- app/core/security.py | 12 +++++++++ app/schemas/gemini_models.py | 8 +++--- app/services/chat_service.py | 52 ++++++++++++++++++++++++++---------- requirements.txt | 1 + 5 files changed, 64 insertions(+), 33 deletions(-) diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index 57e5380..2d2340d 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -17,6 +17,7 @@ logger = get_gemini_logger() security_service = SecurityService(settings.ALLOWED_TOKENS) key_manager = KeyManager(settings.API_KEYS) model_service = ModelService(settings.MODEL_SEARCH) +chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager) @router.get("/models") async def list_models( @@ -32,13 +33,13 @@ async def list_models( @router.post("/models/{model_name}:generateContent") async def generate_content( + model_name: str, request: GeminiRequest, - # key: str = None, - # token: str = Depends(security_service.verify_key), + x_goog_api_key: str = Depends(security_service.verify_goog_api_key), ): """非流式生成内容""" logger.info("-" * 50 + "gemini_generate_content" + "-" * 50) - logger.info(f"Handling Gemini content generation request for model: {request.model}") + logger.info(f"Handling Gemini content generation request for model: {model_name}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") api_key = await key_manager.get_next_working_key() @@ -48,13 +49,9 @@ async def generate_content( while retries < MAX_RETRIES: try: - response = await model_service.generate_content( - contents=request.contents, - model=request.model, - temperature=request.temperature, - candidate_count=request.candidate_count, - top_p=request.top_p, - top_k=request.top_k, + response = await chat_service.generate_content( + model_name=model_name, + request=request, api_key=api_key ) return response @@ -68,14 +65,12 @@ async def generate_content( retries += 1 if retries >= MAX_RETRIES: logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error") - raise @router.post("/models/{model_name}:streamGenerateContent") async def stream_generate_content( model_name: str, request: GeminiRequest, - # x_goog_api_key: str = Header("x-goog-api-key"), - # token: str = Depends(security_service.verify_key), + x_goog_api_key: str = Depends(security_service.verify_goog_api_key), ): """流式生成内容""" logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50) @@ -94,5 +89,4 @@ async def stream_generate_content( return StreamingResponse(response_stream, media_type="text/event-stream") except Exception as e: - logger.error(f"Streaming request failed: {str(e)}") - raise \ No newline at end of file + logger.error(f"Streaming request failed: {str(e)}") \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py index 0734f14..a23f20d 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -34,3 +34,15 @@ class SecurityService: raise HTTPException(status_code=401, detail="Invalid token") return token + + async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str: + """验证Google API Key""" + if not x_goog_api_key: + logger.error("Missing x-goog-api-key header") + raise HTTPException(status_code=401, detail="Missing x-goog-api-key header") + + if x_goog_api_key not in self.allowed_tokens: + logger.error("Invalid x-goog-api-key") + raise HTTPException(status_code=401, detail="Invalid x-goog-api-key") + + return x_goog_api_key diff --git a/app/schemas/gemini_models.py b/app/schemas/gemini_models.py index a43ba8c..0777457 100644 --- a/app/schemas/gemini_models.py +++ b/app/schemas/gemini_models.py @@ -3,7 +3,7 @@ from pydantic import BaseModel class SafetySetting(BaseModel): category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT"]] = None - threshold: Optional[Literal["HARM_BLOCK", "HARM_FLAG", "HARM_UNSPECIFIED"]] = None + threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE","BLOCK_ONLY_HIGH","BLOCK_NONE","OFF"]] = None class GenerationConfig(BaseModel): @@ -33,7 +33,7 @@ class GeminiContent(BaseModel): class GeminiRequest(BaseModel): contents: List[GeminiContent] - # tools: Optional[List[Dict[str, Any]]] = None - # safetySettings: Optional[List[SafetySetting]] = None + tools: Optional[List[Dict[str, Any]]] = [] + safetySettings: Optional[List[SafetySetting]] = None generationConfig: Optional[GenerationConfig] = None - # systemInstruction: Optional[SystemInstruction] = None \ No newline at end of file + systemInstruction: Optional[SystemInstruction] = None \ No newline at end of file diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 8233cda..7f22331 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -312,21 +312,45 @@ class ChatService: api_key: str ) -> AsyncGenerator: """调用Gemini API流式生成内容""" - url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={api_key}" - - timeout = httpx.Timeout(60.0, read=60.0) # 连接超时60秒,读取超时60秒 - async with httpx.AsyncClient(timeout=timeout) as client: + retries = 0 + MAX_RETRIES = 3 + current_api_key = api_key + + while retries < MAX_RETRIES: try: - async with client.stream('POST', url, json=request.model_dump()) as response: - if response.status_code != 200: - error_text = response.text - logger.error(f"Error: {response.status_code}") - logger.error(error_text) - raise Exception(f"API request failed with status {response.status_code}: {error_text}") - - async for line in response.aiter_lines(): - print(line) - yield line + "\n\n" + url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}" + timeout = httpx.Timeout(60.0, read=60.0) + + async with httpx.AsyncClient(timeout=timeout) as client: + async with client.stream('POST', url, json=request.model_dump()) as response: + if response.status_code != 200: + error_text = await response.text() + logger.error(f"Error: {response.status_code}: {error_text}") + if retries < MAX_RETRIES - 1: + current_api_key = await self.key_manager.handle_api_failure(current_api_key) + logger.info(f"Switched to new API key: {current_api_key}") + retries += 1 + continue + raise Exception(f"API request failed with status {response.status_code}: {error_text}") + + async for line in response.aiter_lines(): + yield line + "\n\n" + return + + except httpx.ReadTimeout: + logger.warning(f"Read timeout occurred, attempting retry {retries + 1}") + if retries < MAX_RETRIES - 1: + current_api_key = await self.key_manager.handle_api_failure(current_api_key) + logger.info(f"Switched to new API key: {current_api_key}") + retries += 1 + continue + raise + except Exception as e: logger.error(f"Streaming request failed: {str(e)}") + if retries < MAX_RETRIES - 1: + current_api_key = await self.key_manager.handle_api_failure(current_api_key) + logger.info(f"Switched to new API key: {current_api_key}") + retries += 1 + continue raise \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4f906fd..347af8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ openai pydantic pydantic_settings requests +starlette uvicorn \ No newline at end of file