From 632dee38b3cc6866d56c970b5cc132ffad7a0705 Mon Sep 17 00:00:00 2001 From: Toddy <167494546+toddyoe@users.noreply.github.com> Date: Fri, 14 Mar 2025 04:11:21 +0000 Subject: [PATCH] check model before send request --- app/api/gemini_routes.py | 6 ++++++ app/api/openai_routes.py | 4 ++++ app/services/chat/api_client.py | 16 ++++++++++------ app/services/model_service.py | 14 ++++++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index 9b939d1..31d55ed 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -90,6 +90,9 @@ async def generate_content( logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Using API key: {api_key}") + if not model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + try: response = await chat_service.generate_content( model=model_name, @@ -120,6 +123,9 @@ async def stream_generate_content( logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Using API key: {api_key}") + if not model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + try: response_stream = chat_service.stream_generate_content( model=model_name, diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index e70b4a6..bf3d464 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -61,6 +61,10 @@ async def chat_completion( logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Using API key: {api_key}") + + if not model_service.check_model_support(request.model): + raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported") + try: # 如果model是imagen3,使用paid_key if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": diff --git a/app/services/chat/api_client.py b/app/services/chat/api_client.py index 285a5dd..9469395 100644 --- a/app/services/chat/api_client.py +++ b/app/services/chat/api_client.py @@ -24,12 +24,18 @@ class GeminiApiClient(ApiClient): self.base_url = base_url self.timeout = timeout - async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: - timeout = httpx.Timeout(self.timeout, read=self.timeout) + def _get_real_model(self, model: str) -> str: if model.endswith("-search"): model = model[:-7] if model.endswith("-image"): model = model[:-6] + + return model + + async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + model = self._get_real_model(model) + async with httpx.AsyncClient(timeout=timeout) as client: url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" response = await client.post(url, json=payload) @@ -40,10 +46,8 @@ class GeminiApiClient(ApiClient): async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: timeout = httpx.Timeout(self.timeout, read=self.timeout) - if model.endswith("-search"): - model = model[:-7] - if model.endswith("-image"): - model = model[:-6] + model = self._get_real_model(model) + async with httpx.AsyncClient(timeout=timeout) as client: url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" async with client.stream(method="POST", url=url, json=payload) as response: diff --git a/app/services/model_service.py b/app/services/model_service.py index 9239dfc..9befb9d 100644 --- a/app/services/model_service.py +++ b/app/services/model_service.py @@ -68,3 +68,17 @@ class ModelService: image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat" openai_format["data"].append(image_model) return openai_format + + def check_model_support(self, model: str) -> bool: + if not model or not isinstance(model, str): + return False + + model = model.strip() + if model.endswith("-search"): + model = model[:-7] + return model in settings.MODEL_SEARCH + if model.endswith("-image"): + model = model[:-6] + return model in settings.MODEL_IMAGE + + return True \ No newline at end of file