mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-11 18:09:55 +08:00
feat: Add support for countTokens API and improve debug logging
This commit is contained in:
@@ -80,10 +80,12 @@ class SecurityService:
|
||||
|
||||
# 否则检查请求头中的x-goog-api-key
|
||||
if not x_goog_api_key:
|
||||
logger.debug(f"Failed auth attempt: key='{key}', x_goog_api_key=None")
|
||||
logger.error("Invalid key and missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN:
|
||||
logger.debug(f"Failed auth attempt: key='{key}', x_goog_api_key='{x_goog_api_key}'")
|
||||
logger.error("Invalid key and invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
|
||||
|
||||
|
||||
@@ -151,6 +151,35 @@ async def stream_generate_content(
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
@router_v1beta.post("/models/{model_name}:countTokens")
@RetryHandler(key_arg="api_key")
async def count_tokens(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a Gemini token-count request.

    Validates that the requested model is supported, then delegates the
    actual countTokens call to the chat service using the next working
    upstream API key.
    """
    # key_manager is injected but not referenced directly here;
    # presumably required by RetryHandler's key rotation — TODO confirm.
    op_name = "gemini_count_tokens"
    async with handle_route_errors(logger, op_name, failure_message="Token counting failed"):
        logger.info(f"Handling Gemini token count request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")

        # Reject models the deployment does not expose.
        is_supported = await model_service.check_model_support(model_name)
        if not is_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        return await chat_service.count_tokens(
            model=model_name, request=request, api_key=api_key
        )
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
||||
|
||||
@@ -195,6 +195,54 @@ class GeminiChatService:
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
async def count_tokens(
    self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
    """Count tokens for a Gemini request.

    Forwards only the ``contents`` portion of the request to the upstream
    countTokens endpoint, records success/failure and latency metrics,
    and returns the upstream JSON response.

    Args:
        model: Target model name.
        request: Incoming Gemini request; only its contents are forwarded.
        api_key: Upstream Gemini API key used for this call.

    Returns:
        The parsed countTokens response from the upstream API.

    Raises:
        Exception: Re-raises any error from the underlying API client
            after writing an error-log entry.
    """
    # The countTokens API accepts only "contents"; strip everything else.
    payload = {"contents": request.model_dump().get("contents", [])}
    start_time = time.perf_counter()
    request_datetime = datetime.datetime.now()
    is_success = False
    status_code = None
    response = None

    try:
        response = await self.api_client.count_tokens(payload, model, api_key)
        is_success = True
        status_code = 200
        return response
    except Exception as e:
        is_success = False
        error_log_msg = str(e)
        logger.error(f"Count tokens API call failed with error: {error_log_msg}")
        # Best-effort extraction of the HTTP status from the error text;
        # the API client embeds "status code NNN" in its exception message.
        match = re.search(r"status code (\d+)", error_log_msg)
        if match:
            status_code = int(match.group(1))
        else:
            status_code = 500

        await add_error_log(
            gemini_key=api_key,
            model_name=model,
            error_type="gemini-count-tokens",
            error_log=error_log_msg,
            error_code=status_code,
            request_msg=payload
        )
        raise e
    finally:
        # Always record the request outcome and latency, success or failure.
        end_time = time.perf_counter()
        latency_ms = int((end_time - start_time) * 1000)
        await add_request_log(
            model_name=model,
            api_key=api_key,
            is_success=is_success,
            status_code=status_code,
            latency_ms=latency_ms,
            request_time=request_datetime
        )
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
||||
@@ -21,6 +21,10 @@ class ApiClient(ABC):
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
    """Count tokens for *payload* against *model* using *api_key*.

    Concrete clients must implement this against their provider's
    token-counting endpoint and return the parsed JSON response.
    """
    pass
||||
|
||||
class GeminiApiClient(ApiClient):
|
||||
"""Gemini API客户端"""
|
||||
@@ -108,6 +112,26 @@ class GeminiApiClient(ApiClient):
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
    """Call the Gemini countTokens endpoint and return its JSON body.

    Args:
        payload: Request body; expected to contain "contents".
        model: Model name (resolved to its real upstream name first).
        api_key: Gemini API key, passed as the "key" query parameter.

    Returns:
        The parsed JSON response from the countTokens endpoint.

    Raises:
        Exception: If the upstream responds with a non-200 status; the
            message embeds "status code <N>" for callers that parse it.
    """
    timeout = httpx.Timeout(self.timeout, read=self.timeout)
    model = self._get_real_model(model)

    proxy_to_use = None
    if settings.PROXIES:
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            # NOTE(review): built-in hash() of a str is randomized per
            # process (PYTHONHASHSEED), so this "consistent" proxy choice
            # is only stable within one process — confirm whether
            # cross-restart/cross-worker stickiness is required.
            proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy_to_use = random.choice(settings.PROXIES)
        logger.info(f"Using proxy for counting tokens: {proxy_to_use}")

    async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
        url = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
        response = await client.post(url, json=payload)
        if response.status_code != 200:
            error_content = response.text
            raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
        return response.json()
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
|
||||
Reference in New Issue
Block a user