mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-12 02:19:59 +08:00
新增API密钥列表查询接口并优化Gemini安全设置
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from http.client import HTTPException
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from typing import Optional
|
||||
import logging
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
@@ -97,3 +97,29 @@ async def embedding(
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding request failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@router.get("/hf/v1/keys/list")
|
||||
async def get_keys_list(
|
||||
authorization: str = Header(None),
|
||||
token: str = Depends(security_service.verify_authorization),
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("Handling keys list request")
|
||||
try:
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
return {
|
||||
"status": "success",
|
||||
"data": {
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"]
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching keys list"
|
||||
)
|
||||
|
||||
@@ -174,6 +174,13 @@ class ChatService:
|
||||
"contents": gemini_messages,
|
||||
"generationConfig": {"temperature": temperature},
|
||||
"tools": tools,
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
],
|
||||
}
|
||||
|
||||
if stream:
|
||||
@@ -241,39 +248,39 @@ class ChatService:
|
||||
logger.error(f"Error in non-stream completion: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _openai_chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
stream: bool,
|
||||
api_key: str,
|
||||
tools: Optional[list] = None,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""Handle OpenAI API chat completion"""
|
||||
client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
if tools:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=messages, temperature=temperature, stream=stream
|
||||
)
|
||||
# async def _openai_chat_completion(
|
||||
# self,
|
||||
# messages: list,
|
||||
# model: str,
|
||||
# temperature: float,
|
||||
# stream: bool,
|
||||
# api_key: str,
|
||||
# tools: Optional[list] = None,
|
||||
# ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
# """Handle OpenAI API chat completion"""
|
||||
# client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
# if tools:
|
||||
# response = client.chat.completions.create(
|
||||
# model=model,
|
||||
# messages=messages,
|
||||
# temperature=temperature,
|
||||
# stream=stream,
|
||||
# tools=tools,
|
||||
# )
|
||||
# else:
|
||||
# response = client.chat.completions.create(
|
||||
# model=model, messages=messages, temperature=temperature, stream=stream
|
||||
# )
|
||||
|
||||
if stream:
|
||||
# if stream:
|
||||
|
||||
async def generate():
|
||||
for chunk in response:
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
# async def generate():
|
||||
# for chunk in response:
|
||||
# yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
return generate()
|
||||
# return generate()
|
||||
|
||||
return response
|
||||
# return response
|
||||
|
||||
def format_code_block(self, code_data: dict) -> str:
|
||||
"""格式化代码块输出"""
|
||||
|
||||
@@ -32,7 +32,7 @@ class KeyManager:
|
||||
self.key_failure_counts[key] = 0
|
||||
|
||||
async def get_next_working_key(self) -> str:
|
||||
"""获取下一个可用的API key"""
|
||||
"""获取下一可用的API key"""
|
||||
initial_key = await self.get_next_key()
|
||||
current_key = initial_key
|
||||
|
||||
@@ -55,3 +55,21 @@ class KeyManager:
|
||||
)
|
||||
|
||||
return await self.get_next_working_key()
|
||||
|
||||
async def get_keys_by_status(self) -> dict:
|
||||
"""获取分类后的API key列表"""
|
||||
valid_keys = []
|
||||
invalid_keys = []
|
||||
|
||||
async with self.failure_count_lock:
|
||||
for key in self.api_keys:
|
||||
masked_key = f"{key}"
|
||||
if self.key_failure_counts[key] < self.MAX_FAILURES:
|
||||
valid_keys.append(masked_key)
|
||||
else:
|
||||
invalid_keys.append(masked_key)
|
||||
|
||||
return {
|
||||
"valid_keys": valid_keys,
|
||||
"invalid_keys": invalid_keys
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user