import asyncio from itertools import cycle import logging from typing import Dict logger = logging.getLogger(__name__) class KeyManager: def __init__(self, api_keys: list): self.api_keys = api_keys self.key_cycle = cycle(api_keys) self.key_cycle_lock = asyncio.Lock() self.failure_count_lock = asyncio.Lock() self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys} self.MAX_FAILURES = 10 async def get_next_key(self) -> str: """获取下一个API key""" async with self.key_cycle_lock: return next(self.key_cycle) async def is_key_valid(self, key: str) -> bool: """检查key是否有效""" async with self.failure_count_lock: return self.key_failure_counts[key] < self.MAX_FAILURES async def reset_failure_counts(self): """重置所有key的失败计数""" async with self.failure_count_lock: for key in self.key_failure_counts: self.key_failure_counts[key] = 0 async def get_next_working_key(self) -> str: """获取下一个可用的API key""" initial_key = await self.get_next_key() current_key = initial_key while True: if await self.is_key_valid(current_key): return current_key current_key = await self.get_next_key() if current_key == initial_key: await self.reset_failure_counts() return current_key async def handle_api_failure(self, api_key: str) -> str: """处理API调用失败""" async with self.failure_count_lock: self.key_failure_counts[api_key] += 1 if self.key_failure_counts[api_key] >= self.MAX_FAILURES: logger.warning( f"API key {api_key} has failed {self.MAX_FAILURES} times" ) return await self.get_next_working_key()