mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-19 23:30:47 +08:00
76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
import asyncio
|
|
from itertools import cycle
|
|
from typing import Dict
|
|
from app.core.logger import get_key_manager_logger
|
|
|
|
# Module-level logger used by KeyManager, obtained from the project's logging helper.
logger = get_key_manager_logger()
|
|
|
|
|
|
class KeyManager:
    """Round-robin rotation over a pool of API keys with per-key failure counting.

    Keys are handed out in the order given, cycling forever. Each key has a
    failure counter; a key is considered valid while its counter is below
    ``MAX_FAILURES``. Within this class, counters are cleared only by
    ``reset_failure_counts``.
    """

    def __init__(self, api_keys: list, max_failures: int = 10):
        """Create a manager for ``api_keys``.

        Args:
            api_keys: The API keys to rotate through, in rotation order.
                The list object is stored as-is (not copied).
            max_failures: Number of recorded failures at which a key is
                treated as invalid. Defaults to 10, the previous fixed value.
        """
        self.api_keys = api_keys
        self.key_cycle = cycle(api_keys)
        # Serializes advancement of the shared round-robin iterator.
        self.key_cycle_lock = asyncio.Lock()
        # Serializes reads/writes of key_failure_counts.
        self.failure_count_lock = asyncio.Lock()
        self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
        self.MAX_FAILURES = max_failures

    @staticmethod
    def _mask_key(key: str) -> str:
        """Return a log-safe form of ``key`` exposing only its first/last 4 chars."""
        if len(key) <= 8:
            # Too short to reveal any part without leaking most of it.
            return "*" * len(key)
        return f"{key[:4]}...{key[-4:]}"

    async def get_next_key(self) -> str:
        """Return the next API key in round-robin order, whether valid or not.

        Raises:
            RuntimeError: If no API keys were configured.
        """
        async with self.key_cycle_lock:
            try:
                return next(self.key_cycle)
            except StopIteration:
                # cycle([]) is exhausted immediately; a StopIteration escaping a
                # coroutine would surface as an opaque RuntimeError anyway.
                raise RuntimeError("No API keys configured") from None

    async def is_key_valid(self, key: str) -> bool:
        """Return True if ``key`` has fewer than ``MAX_FAILURES`` recorded failures.

        Raises:
            KeyError: If ``key`` is not one of the managed keys.
        """
        async with self.failure_count_lock:
            return self.key_failure_counts[key] < self.MAX_FAILURES

    async def reset_failure_counts(self):
        """Reset the failure counter of every managed key to zero."""
        async with self.failure_count_lock:
            for key in self.key_failure_counts:
                self.key_failure_counts[key] = 0

    async def get_next_working_key(self) -> str:
        """Return the next key, in rotation order, that is still valid.

        At most one full rotation of ``api_keys`` is inspected. If no valid
        key is found, the first key drawn is returned anyway (failure counts
        are deliberately not reset here), so callers may receive an invalid key.

        Raises:
            RuntimeError: If no API keys were configured.
        """
        first_key = await self.get_next_key()
        candidate = first_key
        # Bound the scan by the pool size rather than stopping when first_key
        # shows up again: with duplicate keys in the pool that comparison ended
        # the scan early and could hide a later valid key.
        for _ in range(len(self.api_keys)):
            if await self.is_key_valid(candidate):
                return candidate
            candidate = await self.get_next_key()
        return first_key

    async def handle_api_failure(self, api_key: str) -> str:
        """Record a failed call made with ``api_key`` and return a key to retry with.

        Increments the key's failure counter and logs a warning, with the key
        masked, whenever the counter is at or above ``MAX_FAILURES``.

        Returns:
            The result of ``get_next_working_key``.

        Raises:
            KeyError: If ``api_key`` is not one of the managed keys.
        """
        async with self.failure_count_lock:
            self.key_failure_counts[api_key] += 1
            failures = self.key_failure_counts[api_key]

        if failures >= self.MAX_FAILURES:
            # Never write the full secret key to the logs; report the real count.
            logger.warning(
                f"API key {self._mask_key(api_key)} has failed {failures} times"
            )

        return await self.get_next_working_key()

    async def get_keys_by_status(self) -> dict:
        """Split the configured keys into valid and invalid lists.

        Returns:
            ``{"valid_keys": [...], "invalid_keys": [...]}``, each list in the
            order of ``api_keys``. NOTE: keys are returned in full, not masked.
        """
        valid_keys = []
        invalid_keys = []

        async with self.failure_count_lock:
            for key in self.api_keys:
                key_text = f"{key}"  # plain string form of the key (no masking)
                if self.key_failure_counts[key] < self.MAX_FAILURES:
                    valid_keys.append(key_text)
                else:
                    invalid_keys.append(key_text)

        return {
            "valid_keys": valid_keys,
            "invalid_keys": invalid_keys
        }
|