Files
gemini-balance/app/services/key_manager.py

106 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
from itertools import cycle
from typing import Dict
from app.core.logger import get_key_manager_logger
from app.core.config import settings
# Module-level logger; KeyManager uses it to warn when a key reaches the failure limit.
logger = get_key_manager_logger()
class KeyManager:
    """Rotate through a pool of API keys and track per-key failure counts.

    Keys are handed out round-robin from ``api_keys``. Every failure reported
    through :meth:`handle_api_failure` increments that key's counter; a key
    whose counter has reached ``MAX_FAILURES`` is treated as invalid and is
    passed over by :meth:`get_next_working_key`.

    The rotation iterator and the failure counters are each guarded by their
    own ``asyncio.Lock``.
    """

    def __init__(self, api_keys: list):
        """Build the rotation and zeroed failure counters for ``api_keys``.

        ``MAX_FAILURES`` and ``paid_key`` are read from ``settings``.
        """
        self.api_keys = api_keys
        self.key_cycle = cycle(api_keys)
        self.key_cycle_lock = asyncio.Lock()
        self.failure_count_lock = asyncio.Lock()
        self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
        self.MAX_FAILURES = settings.MAX_FAILURES
        self.paid_key = settings.PAID_KEY

    async def get_paid_key(self) -> str:
        """Return the paid key taken from ``settings.PAID_KEY``."""
        return self.paid_key

    async def get_next_key(self) -> str:
        """Return the next API key in round-robin order.

        Raises:
            RuntimeError: if the key pool is empty.
        """
        async with self.key_cycle_lock:
            try:
                return next(self.key_cycle)
            except StopIteration:
                # cycle([]) is exhausted immediately. A StopIteration leaking
                # out of a coroutine is converted into an opaque RuntimeError,
                # so raise the same exception type with a readable message.
                raise RuntimeError("No API keys configured") from None

    async def is_key_valid(self, key: str) -> bool:
        """Return True while ``key`` has failed fewer than ``MAX_FAILURES`` times.

        A key that is not tracked counts as having zero failures, matching
        :meth:`get_fail_count`.
        """
        async with self.failure_count_lock:
            return self.key_failure_counts.get(key, 0) < self.MAX_FAILURES

    async def reset_failure_counts(self):
        """Reset the failure counter of every tracked key to zero."""
        async with self.failure_count_lock:
            for key in self.key_failure_counts:
                self.key_failure_counts[key] = 0

    async def get_next_working_key(self) -> str:
        """Return the next key in the rotation that is still valid.

        The rotation is advanced until a valid key is found. If it comes back
        around to the first key examined, that key is returned even though it
        is over the failure limit, so the caller always receives a key.

        NOTE(review): wrap-around is detected by comparing key values, which
        presumes the entries of ``api_keys`` are unique - confirm with callers.
        """
        initial_key = await self.get_next_key()
        current_key = initial_key
        while True:
            if await self.is_key_valid(current_key):
                return current_key
            current_key = await self.get_next_key()
            if current_key == initial_key:
                # Every key was tried. Failure counts are deliberately NOT
                # reset here (the reset was disabled), so just fall back to
                # the starting key.
                return current_key

    async def handle_api_failure(self, api_key: str) -> str:
        """Record a failed call for ``api_key`` and return a replacement key.

        The key's failure counter is incremented, and a warning is logged once
        it has reached ``MAX_FAILURES``. A key that was not part of the initial
        pool starts counting from zero instead of raising ``KeyError``.

        Returns:
            The result of :meth:`get_next_working_key`.
        """
        async with self.failure_count_lock:
            failures = self.key_failure_counts.get(api_key, 0) + 1
            self.key_failure_counts[api_key] = failures
            if failures >= self.MAX_FAILURES:
                # NOTE(review): the full key is written to the log; consider
                # whether it should be redacted.
                logger.warning(
                    f"API key {api_key} has failed {self.MAX_FAILURES} times"
                )
        # Must run after the lock is released: get_next_working_key() acquires
        # failure_count_lock again through is_key_valid().
        return await self.get_next_working_key()

    def get_fail_count(self, key: str) -> int:
        """Return the recorded failure count for ``key`` (0 if untracked)."""
        return self.key_failure_counts.get(key, 0)

    async def get_keys_by_status(self) -> dict:
        """Classify the pool's keys by validity, with their failure counts.

        Returns:
            ``{"valid_keys": {key: count}, "invalid_keys": {key: count}}``
            where a key is valid while its count is below ``MAX_FAILURES``.
            Only keys listed in ``api_keys`` are reported.
        """
        valid_keys = {}
        invalid_keys = {}
        async with self.failure_count_lock:
            for key in self.api_keys:
                fail_count = self.key_failure_counts.get(key, 0)
                if fail_count < self.MAX_FAILURES:
                    valid_keys[key] = fail_count
                else:
                    invalid_keys[key] = fail_count
        return {
            "valid_keys": valid_keys,
            "invalid_keys": invalid_keys
        }
# Module-wide KeyManager instance; stays None until get_key_manager_instance()
# creates it on first use.
_singleton_instance = None
# Held by get_key_manager_instance() while it checks for and creates the instance.
# NOTE(review): this lock is constructed at import time; on Python versions
# before 3.10 asyncio.Lock binds to the event loop current at construction -
# confirm the target runtime.
_singleton_lock = asyncio.Lock()
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
    """Return the module-wide KeyManager, creating it on first use.

    The first call must supply ``api_keys``, which are used to construct the
    KeyManager. Once an instance exists, ``api_keys`` is ignored and the
    existing instance is returned.

    Raises:
        ValueError: if no instance exists yet and ``api_keys`` is None.
    """
    global _singleton_instance
    async with _singleton_lock:
        # Fast path: the instance has already been built.
        if _singleton_instance is not None:
            return _singleton_instance
        if api_keys is None:
            raise ValueError("API keys are required to initialize the KeyManager")
        manager = KeyManager(api_keys)
        _singleton_instance = manager
        return manager