From 0dd9dd53804d47d01d3debad9499202fb7511dcc Mon Sep 17 00:00:00 2001 From: snaily Date: Sat, 12 Apr 2025 21:35:38 +0800 Subject: [PATCH] =?UTF-8?q?refactor(config):=20=E5=B0=86=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=94=B9=E4=B8=BA=E4=BB=8E=20settings=20?= =?UTF-8?q?=E8=8E=B7=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 SecurityService, ModelService, EmbeddingService 的配置依赖从构造函数注入改为直接从 app.config.config.settings 获取。 这简化了服务类的实例化过程,并实现了配置的集中管理。 --- app/core/security.py | 17 +++++++---------- app/router/gemini_routes.py | 4 ++-- app/router/openai_routes.py | 6 +++--- app/service/embedding/embedding_service.py | 6 ++---- app/service/model/model_service.py | 20 +++++++------------- 5 files changed, 21 insertions(+), 32 deletions(-) diff --git a/app/core/security.py b/app/core/security.py index a759823..eebad69 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -13,12 +13,9 @@ def verify_auth_token(token: str) -> bool: class SecurityService: - def __init__(self, allowed_tokens: list, auth_token: str): - self.allowed_tokens = allowed_tokens - self.auth_token = auth_token async def verify_key(self, key: str): - if key not in self.allowed_tokens and key != self.auth_token: + if key not in settings.ALLOWED_TOKENS and key != settings.AUTH_TOKEN: logger.error("Invalid key") raise HTTPException(status_code=401, detail="Invalid key") return key @@ -37,7 +34,7 @@ class SecurityService: ) token = authorization.replace("Bearer ", "") - if token not in self.allowed_tokens and token != self.auth_token: + if token not in settings.ALLOWED_TOKENS and token != settings.AUTH_TOKEN: logger.error("Invalid token") raise HTTPException(status_code=401, detail="Invalid token") @@ -52,8 +49,8 @@ class SecurityService: raise HTTPException(status_code=401, detail="Missing x-goog-api-key header") if ( - x_goog_api_key not in self.allowed_tokens - and x_goog_api_key != self.auth_token + x_goog_api_key not in settings.ALLOWED_TOKENS + and x_goog_api_key != settings.AUTH_TOKEN ): logger.error("Invalid x-goog-api-key") raise HTTPException(status_code=401, detail="Invalid x-goog-api-key") @@ -67,7 +64,7 @@ class SecurityService: logger.error("Missing auth_token header") raise HTTPException(status_code=401, detail="Missing auth_token header") token = authorization.replace("Bearer ", "") - if token != self.auth_token: + if token != settings.AUTH_TOKEN: logger.error("Invalid auth_token") raise HTTPException(status_code=401, detail="Invalid auth_token") @@ -78,7 +75,7 @@ class SecurityService: ) -> str: """验证URL中的key或请求头中的x-goog-api-key""" # 如果URL中的key有效,直接返回 - if key in self.allowed_tokens or key == self.auth_token: + if key in settings.ALLOWED_TOKENS or key == settings.AUTH_TOKEN: return key # 否则检查请求头中的x-goog-api-key @@ -86,7 +83,7 @@ class SecurityService: logger.error("Invalid key and missing x-goog-api-key header") raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header") - if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token: + if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN: logger.error("Invalid key and invalid x-goog-api-key") raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key") diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index e879bb4..2206a14 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -17,8 +17,8 @@ router_v1beta = APIRouter(prefix=f"/{API_VERSION}") logger = get_gemini_logger() # 初始化服务 -security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) -model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS) +security_service = SecurityService() +model_service = ModelService() async def get_key_manager(): diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py index 65de8c1..cd9baf3 100644 --- a/app/router/openai_routes.py +++ b/app/router/openai_routes.py @@ -20,9 +20,9 @@ router = APIRouter() logger = get_openai_logger() # 初始化服务 -security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) -model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS) -embedding_service = EmbeddingService(settings.BASE_URL) +security_service = SecurityService() +model_service = ModelService() +embedding_service = EmbeddingService() image_create_service = ImageCreateService() diff --git a/app/service/embedding/embedding_service.py b/app/service/embedding/embedding_service.py index 6823099..a874ae2 100644 --- a/app/service/embedding/embedding_service.py +++ b/app/service/embedding/embedding_service.py @@ -2,22 +2,20 @@ from typing import List, Union import openai from openai.types import CreateEmbeddingResponse - +from app.config.config import settings from app.log.logger import get_embeddings_logger logger = get_embeddings_logger() class EmbeddingService: - def __init__(self, base_url: str): - self.base_url = base_url async def create_embedding( self, input_text: Union[str, List[str]], model: str, api_key: str ) -> CreateEmbeddingResponse: """Create embeddings using OpenAI API""" try: - client = openai.OpenAI(api_key=api_key, base_url=self.base_url) + client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL) response = client.embeddings.create(input=input_text, model=model) return response except Exception as e: diff --git a/app/service/model/model_service.py b/app/service/model/model_service.py index 5bcc044..d220d50 100644 --- a/app/service/model/model_service.py +++ b/app/service/model/model_service.py @@ -10,14 +10,8 @@ logger = get_model_logger() class ModelService: - def __init__(self, search_models: list, image_models: list): - self.search_models = search_models - self.image_models = image_models - self.base_url = settings.BASE_URL - self.filtered_models = settings.FILTERED_MODELS - def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: - url = f"{self.base_url}/models?key={api_key}" + url = f"{settings.BASE_URL}/models?key={api_key}" try: response = requests.get(url) @@ -27,7 +21,7 @@ class ModelService: filtered_models_list = [] for model in gemini_models.get("models", []): model_id = model["name"].split("/")[-1] - if model_id not in self.filtered_models: + if model_id not in settings.FILTERED_MODELS: filtered_models_list.append(model) else: logger.info(f"Filtered out model: {model_id}") @@ -68,11 +62,11 @@ class ModelService: } openai_format["data"].append(openai_model) - if model_id in self.search_models: + if model_id in settings.SEARCH_MODELS: search_model = openai_model.copy() search_model["id"] = f"{model_id}-search" openai_format["data"].append(search_model) - if model_id in self.image_models: + if model_id in settings.IMAGE_MODELS: image_model = openai_model.copy() image_model["id"] = f"{model_id}-image" openai_format["data"].append(image_model) @@ -90,9 +84,9 @@ class ModelService: model = model.strip() if model.endswith("-search"): model = model[:-7] - return model in self.search_models + return model in settings.SEARCH_MODELS if model.endswith("-image"): model = model[:-6] - return model in self.image_models + return model in settings.IMAGE_MODELS - return model not in self.filtered_models + return model not in settings.FILTERED_MODELS