From baf643e884b2087f2c23a03e8927a21c6b87ca68 Mon Sep 17 00:00:00 2001 From: snaily Date: Thu, 3 Apr 2025 03:12:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E8=B6=85=E6=97=B6=E9=85=8D=E7=BD=AE=E5=8F=8A=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=97=E8=A1=A8=E6=8E=A5=E5=8F=A3api=5Fkey?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 新增功能: - 在`.env.example`中添加`TIME_OUT=300`配置项(包含中文注释) - 在`Settings`类中增加`TIME_OUT`字段(读取自`DEFAULT_TIMEOUT`) 2. 优化内容: - 生成配置: * 为`GenerationConfig`设置默认温度/TOP_P/TOP_K值 * 移除`maxOutputTokens`默认值,改为可选传递 - OpenAI请求: * 移除`max_tokens`默认值 * 只有当`max_tokens`有值时才添加到请求payload - 日志优化: * 注释掉`stream_optimizer.py`中部分调试日志 3. 模型列表接口api_key获取方式 --- .env.example | 2 ++ app/config/config.py | 3 ++- app/domain/gemini_models.py | 8 +++++--- app/domain/openai_models.py | 4 ++-- app/handler/stream_optimizer.py | 8 ++++---- app/router/gemini_routes.py | 2 +- app/router/openai_routes.py | 2 +- app/service/chat/gemini_chat_service.py | 5 +++++ app/service/chat/openai_chat_service.py | 3 ++- app/service/key/key_manager.py | 7 +++++++ 10 files changed, 31 insertions(+), 13 deletions(-) diff --git a/.env.example b/.env.example index 49587d1..0df009a 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,8 @@ SHOW_SEARCH_LINK=true SHOW_THINKING_PROCESS=true BASE_URL=https://generativelanguage.googleapis.com/v1beta MAX_FAILURES=10 +# 请求超时时间(秒) +TIME_OUT=300 #########################image_generate 相关配置########################### PAID_KEY=AIzaSyxxxxxxxxxxxxxxxxxxx CREATE_IMAGE_MODEL=imagen-3.0-generate-002 diff --git a/app/config/config.py b/app/config/config.py index 8630cc6..78cf94a 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -4,7 +4,7 @@ from typing import List from pydantic_settings import BaseSettings -from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD +from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, DEFAULT_TIMEOUT class Settings(BaseSettings): @@ -16,6 +16,7 @@ class Settings(BaseSettings): AUTH_TOKEN: str = "" MAX_FAILURES: int = 3 TEST_MODEL: str = DEFAULT_MODEL + TIME_OUT: int = DEFAULT_TIMEOUT # 模型相关配置 SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] diff --git a/app/domain/gemini_models.py b/app/domain/gemini_models.py index d2c07df..9da1820 100644 --- a/app/domain/gemini_models.py +++ b/app/domain/gemini_models.py @@ -1,6 +1,8 @@ from typing import List, Optional, Dict, Any, Literal, Union from pydantic import BaseModel +from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P + class SafetySetting(BaseModel): category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None @@ -13,9 +15,9 @@ class GenerationConfig(BaseModel): responseSchema: Optional[Dict[str, Any]] = None candidateCount: Optional[int] = 1 maxOutputTokens: Optional[int] = None - temperature: Optional[float] = None - topP: Optional[float] = None - topK: Optional[int] = None + temperature: Optional[float] = DEFAULT_TEMPERATURE + topP: Optional[float] = DEFAULT_TOP_P + topK: Optional[int] = DEFAULT_TOP_K presencePenalty: Optional[float] = None frequencyPenalty: Optional[float] = None responseLogprobs: Optional[bool] = None diff --git a/app/domain/openai_models.py b/app/domain/openai_models.py index 69a28b0..ad6d326 100644 --- a/app/domain/openai_models.py +++ b/app/domain/openai_models.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from typing import List, Optional, Union -from app.core.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P +from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P class ChatRequest(BaseModel): @@ -10,7 +10,7 @@ class ChatRequest(BaseModel): temperature: Optional[float] = DEFAULT_TEMPERATURE stream: Optional[bool] = False tools: Optional[List[dict]] = [] - max_tokens: Optional[int] = DEFAULT_MAX_TOKENS + max_tokens: Optional[int] = None top_p: Optional[float] = DEFAULT_TOP_P top_k: Optional[int] = DEFAULT_TOP_K stop: Optional[List[str]] = [] diff --git a/app/handler/stream_optimizer.py b/app/handler/stream_optimizer.py index 2356da3..cf1332d 100644 --- a/app/handler/stream_optimizer.py +++ b/app/handler/stream_optimizer.py @@ -107,15 +107,15 @@ class StreamOptimizer: # 计算智能延迟时间 delay = self.calculate_delay(len(text)) - if self.logger: - self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s") + # if self.logger: + # self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s") # 根据文本长度决定输出方式 if len(text) >= self.long_text_threshold: # 长文本:分块输出 chunks = self.split_text_into_chunks(text) - if self.logger: - self.logger.info(f"Long text: splitting into {len(chunks)} chunks") + # if self.logger: + # self.logger.info(f"Long text: splitting into {len(chunks)} chunks") for chunk_text in chunks: chunk_response = create_response_chunk(chunk_text) yield format_chunk(chunk_response) diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index a8163c0..cf330b3 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -41,7 +41,7 @@ async def list_models( logger.info("-" * 50 + "list_gemini_models" + "-" * 50) logger.info("Handling Gemini models list request") - api_key = await key_manager.get_next_working_key() + api_key = await key_manager.get_first_valid_key() logger.info(f"Using API key: {api_key}") models_json = model_service.get_gemini_models(api_key) diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py index a99dc34..a378211 100644 --- a/app/router/openai_routes.py +++ b/app/router/openai_routes.py @@ -44,7 +44,7 @@ async def list_models( ): logger.info("-" * 50 + "list_models" + "-" * 50) logger.info("Handling models list request") - api_key = await key_manager.get_next_working_key() + api_key = await key_manager.get_first_valid_key() logger.info(f"Using API key: {api_key}") try: return model_service.get_gemini_openai_models(api_key) diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py index b109592..bda8484 100644 --- a/app/service/chat/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -89,6 +89,11 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]: def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: """构建请求payload""" request_dict = request.model_dump() + if request.generationConfig: + if request.generationConfig.maxOutputTokens is None: + # 如果未指定最大输出长度,则不传递该字段,解决截断的问题 + request_dict["generationConfig"].pop("maxOutputTokens") + payload = { "contents": request_dict.get("contents", []), "tools": _build_tools(model, request_dict), diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index b967c56..4b94816 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -115,7 +115,6 @@ def _build_payload( "contents": messages, "generationConfig": { "temperature": request.temperature, - "maxOutputTokens": request.max_tokens, "stopSequences": request.stop, "topP": request.top_p, "topK": request.top_k, @@ -123,6 +122,8 @@ def _build_payload( "tools": _build_tools(request, messages), "safetySettings": _get_safety_settings(request.model), } + if request.max_tokens is not None: + payload["generationConfig"]["maxOutputTokens"] = request.max_tokens if request.model.endswith("-image") or request.model.endswith("-image-generation"): payload["generationConfig"]["responseModalities"] = ["Text", "Image"] diff --git a/app/service/key/key_manager.py b/app/service/key/key_manager.py index e50d3d3..096ac5d 100644 --- a/app/service/key/key_manager.py +++ b/app/service/key/key_manager.py @@ -81,6 +81,13 @@ class KeyManager: return {"valid_keys": valid_keys, "invalid_keys": invalid_keys} + async def get_first_valid_key(self) -> str: + """获取第一个有效的API key""" + async with self.failure_count_lock: + for key in self.key_failure_counts: + if self.key_failure_counts[key] < self.MAX_FAILURES: + return key + return self.api_keys[0] _singleton_instance = None _singleton_lock = asyncio.Lock()