From bd9dbfb0e39e5c994a4b991a735b1c9c07f337ea Mon Sep 17 00:00:00 2001
From: yinpeng <2291314224@qq.com>
Date: Sun, 15 Dec 2024 17:17:09 +0800
Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9EAPI=E5=AF=86=E9=92=A5?=
 =?UTF-8?q?=E5=88=97=E8=A1=A8=E6=9F=A5=E8=AF=A2=E6=8E=A5=E5=8F=A3=E5=B9=B6?=
 =?UTF-8?q?=E4=BC=98=E5=8C=96Gemini=E5=AE=89=E5=85=A8=E8=AE=BE=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
NOTE(review): three fixes folded into this patch before merge (this region
between "---" and the diffstat is ignored by `git am`):
  1. HTTPException must be imported from fastapi, not http.client — the stdlib
     class accepts no status_code/detail kwargs (TypeError at raise time) and
     FastAPI would not render it as an HTTP 500 response.
  2. /v1/keys/list now actually masks keys before returning them; the original
     bound the full key to a variable named "masked_key", leaking secrets.
  3. Reverted an accidental docstring regression in get_next_working_key.
Indentation of context lines was reconstructed from a whitespace-mangled
copy of this patch — re-verify with `git apply --check` before merging.

 app/api/routes.py            | 28 +++++++++++++++-
 app/services/chat_service.py | 65 ++++++++++++++++++++----------------
 app/services/key_manager.py  | 20 ++++++++++-
 3 files changed, 82 insertions(+), 31 deletions(-)

diff --git a/app/api/routes.py b/app/api/routes.py
index 36f0a57..e87c8da 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -1,5 +1,5 @@
+from fastapi import HTTPException
 from fastapi import APIRouter, Depends, Header
-from typing import Optional
 import logging
 
 from fastapi.responses import StreamingResponse
@@ -97,3 +97,29 @@ async def embedding(
     except Exception as e:
         logger.error(f"Embedding request failed: {str(e)}")
         raise
+
+
+@router.get("/v1/keys/list")
+@router.get("/hf/v1/keys/list")
+async def get_keys_list(
+    authorization: str = Header(None),
+    token: str = Depends(security_service.verify_authorization),
+):
+    """获取有效和无效的API key列表"""
+    logger.info("Handling keys list request")
+    try:
+        keys_status = await key_manager.get_keys_by_status()
+        return {
+            "status": "success",
+            "data": {
+                "valid_keys": keys_status["valid_keys"],
+                "invalid_keys": keys_status["invalid_keys"]
+            },
+            "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
+        }
+    except Exception as e:
+        logger.error(f"Error getting keys list: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail="Internal server error while fetching keys list"
+        )
diff --git a/app/services/chat_service.py b/app/services/chat_service.py
index 2b33a39..6d90ffb 100644
--- a/app/services/chat_service.py
+++ b/app/services/chat_service.py
@@ -174,6 +174,13 @@ class ChatService:
             "contents": gemini_messages,
             "generationConfig": {"temperature": temperature},
             "tools": tools,
+            "safetySettings": [
+                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
+            ],
         }
 
         if stream:
@@ -241,39 +248,39 @@ class ChatService:
             logger.error(f"Error in non-stream completion: {str(e)}")
             raise
 
-    async def _openai_chat_completion(
-        self,
-        messages: list,
-        model: str,
-        temperature: float,
-        stream: bool,
-        api_key: str,
-        tools: Optional[list] = None,
-    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
-        """Handle OpenAI API chat completion"""
-        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
-        if tools:
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=temperature,
-                stream=stream,
-                tools=tools,
-            )
-        else:
-            response = client.chat.completions.create(
-                model=model, messages=messages, temperature=temperature, stream=stream
-            )
+    # async def _openai_chat_completion(
+    #     self,
+    #     messages: list,
+    #     model: str,
+    #     temperature: float,
+    #     stream: bool,
+    #     api_key: str,
+    #     tools: Optional[list] = None,
+    # ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
+    #     """Handle OpenAI API chat completion"""
+    #     client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
+    #     if tools:
+    #         response = client.chat.completions.create(
+    #             model=model,
+    #             messages=messages,
+    #             temperature=temperature,
+    #             stream=stream,
+    #             tools=tools,
+    #         )
+    #     else:
+    #         response = client.chat.completions.create(
+    #             model=model, messages=messages, temperature=temperature, stream=stream
+    #         )
 
-        if stream:
+        # if stream:
 
-            async def generate():
-                for chunk in response:
-                    yield f"data: {chunk.model_dump_json()}\n\n"
+            # async def generate():
+            #     for chunk in response:
+            #         yield f"data: {chunk.model_dump_json()}\n\n"
 
-            return generate()
+            # return generate()
 
-        return response
+        # return response
 
     def format_code_block(self, code_data: dict) -> str:
         """格式化代码块输出"""
diff --git a/app/services/key_manager.py b/app/services/key_manager.py
index c9cb5ca..1a1b308 100644
--- a/app/services/key_manager.py
+++ b/app/services/key_manager.py
@@ -32,7 +32,7 @@ class KeyManager:
             self.key_failure_counts[key] = 0
 
     async def get_next_working_key(self) -> str:
-        """获取下一个可用的API key"""
+        """获取下一个可用的API key"""
         initial_key = await self.get_next_key()
         current_key = initial_key
 
@@ -55,3 +55,21 @@
             )
 
         return await self.get_next_working_key()
+
+    async def get_keys_by_status(self) -> dict:
+        """获取分类后的API key列表"""
+        valid_keys = []
+        invalid_keys = []
+
+        async with self.failure_count_lock:
+            for key in self.api_keys:
+                masked_key = f"{key[:4]}...{key[-4:]}"
+                if self.key_failure_counts[key] < self.MAX_FAILURES:
+                    valid_keys.append(masked_key)
+                else:
+                    invalid_keys.append(masked_key)
+
+        return {
+            "valid_keys": valid_keys,
+            "invalid_keys": invalid_keys
+        }