From 1199d7cc3cf9d138fba7c77d8ec2a98d53828fe3 Mon Sep 17 00:00:00 2001
From: lc631017672 <631017672@qq.com>
Date: Mon, 7 Jul 2025 10:08:57 +0800
Subject: [PATCH] feat: Add support for countTokens API and improve debug logging

---
 app/core/security.py                    |  2 ++
 app/router/gemini_routes.py             | 29 +++++++++++++++
 app/service/chat/gemini_chat_service.py | 48 +++++++++++++++++++++++++
 app/service/client/api_client.py        | 24 +++++++++++++
 4 files changed, 103 insertions(+)

diff --git a/app/core/security.py b/app/core/security.py
index eebad69..77b1e27 100644
--- a/app/core/security.py
+++ b/app/core/security.py
@@ -80,10 +80,12 @@ class SecurityService:
 
         # 否则检查请求头中的x-goog-api-key
         if not x_goog_api_key:
+            logger.debug(f"Failed auth attempt: key='{key}', x_goog_api_key=None")
             logger.error("Invalid key and missing x-goog-api-key header")
             raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")
 
         if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN:
+            logger.debug(f"Failed auth attempt: key='{key}', x_goog_api_key='{x_goog_api_key}'")
             logger.error("Invalid key and invalid x-goog-api-key")
             raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")
 
diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py
index 95bf88a..1b5858b 100644
--- a/app/router/gemini_routes.py
+++ b/app/router/gemini_routes.py
@@ -151,6 +151,35 @@ async def stream_generate_content(
     return StreamingResponse(response_stream, media_type="text/event-stream")
 
 
+@router.post("/models/{model_name}:countTokens")
+@router_v1beta.post("/models/{model_name}:countTokens")
+@RetryHandler(key_arg="api_key")
+async def count_tokens(
+    model_name: str,
+    request: GeminiRequest,
+    _=Depends(security_service.verify_key_or_goog_api_key),
+    api_key: str = Depends(get_next_working_key),
+    key_manager: KeyManager = Depends(get_key_manager),
+    chat_service: GeminiChatService = Depends(get_chat_service)
+):
+    """处理 Gemini token 计数请求。"""
+    operation_name = "gemini_count_tokens"
+    async with handle_route_errors(logger, operation_name, failure_message="Token counting failed"):
+        logger.info(f"Handling Gemini token count request for model: {model_name}")
+        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
+        logger.info(f"Using API key: {api_key}")
+
+        if not await model_service.check_model_support(model_name):
+            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
+
+        response = await chat_service.count_tokens(
+            model=model_name,
+            request=request,
+            api_key=api_key
+        )
+        return response
+
+
 @router.post("/reset-all-fail-counts")
 async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
     """批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py
index cca82ae..622ad28 100644
--- a/app/service/chat/gemini_chat_service.py
+++ b/app/service/chat/gemini_chat_service.py
@@ -195,6 +195,54 @@ class GeminiChatService:
             request_time=request_datetime
         )
 
+    async def count_tokens(
+        self, model: str, request: GeminiRequest, api_key: str
+    ) -> Dict[str, Any]:
+        """计算token数量"""
+        # countTokens API只需要contents
+        payload = {"contents": request.model_dump().get("contents", [])}
+        start_time = time.perf_counter()
+        request_datetime = datetime.datetime.now()
+        is_success = False
+        status_code = None
+        response = None
+
+        try:
+            response = await self.api_client.count_tokens(payload, model, api_key)
+            is_success = True
+            status_code = 200
+            return response
+        except Exception as e:
+            is_success = False
+            error_log_msg = str(e)
+            logger.error(f"Count tokens API call failed with error: {error_log_msg}")
+            match = re.search(r"status code (\d+)", error_log_msg)
+            if match:
+                status_code = int(match.group(1))
+            else:
+                status_code = 500
+
+            await add_error_log(
+                gemini_key=api_key,
+                model_name=model,
+                error_type="gemini-count-tokens",
+                error_log=error_log_msg,
+                error_code=status_code,
+                request_msg=payload
+            )
+            raise e
+        finally:
+            end_time = time.perf_counter()
+            latency_ms = int((end_time - start_time) * 1000)
+            await add_request_log(
+                model_name=model,
+                api_key=api_key,
+                is_success=is_success,
+                status_code=status_code,
+                latency_ms=latency_ms,
+                request_time=request_datetime
+            )
+
     async def stream_generate_content(
         self, model: str, request: GeminiRequest, api_key: str
     ) -> AsyncGenerator[str, None]:
diff --git a/app/service/client/api_client.py b/app/service/client/api_client.py
index 10d4391..641b8d6 100644
--- a/app/service/client/api_client.py
+++ b/app/service/client/api_client.py
@@ -21,6 +21,10 @@ class ApiClient(ABC):
     async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
         pass
 
+    @abstractmethod
+    async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
+        pass
+
 
 class GeminiApiClient(ApiClient):
     """Gemini API客户端"""
@@ -108,6 +108,26 @@ class GeminiApiClient(ApiClient):
             async for line in response.aiter_lines():
                 yield line
 
+    async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
+        timeout = httpx.Timeout(self.timeout, read=self.timeout)
+        model = self._get_real_model(model)
+
+        proxy_to_use = None
+        if settings.PROXIES:
+            if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
+                proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
+            else:
+                proxy_to_use = random.choice(settings.PROXIES)
+            logger.info(f"Using proxy for counting tokens: {proxy_to_use}")
+
+        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
+            url = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
+            response = await client.post(url, json=payload)
+            if response.status_code != 200:
+                error_content = response.text
+                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
+            return response.json()
+
 
 class OpenaiApiClient(ApiClient):
     """OpenAI API客户端"""