mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-29 03:31:40 +08:00
check model before send request
This commit is contained in:
@@ -90,6 +90,9 @@ async def generate_content(
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
@@ -120,6 +123,9 @@ async def stream_generate_content(
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
|
||||
@@ -61,6 +61,10 @@ async def chat_completion(
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
|
||||
@@ -24,12 +24,18 @@ class GeminiApiClient(ApiClient):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
def _get_real_model(self, model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
|
||||
return model
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
@@ -40,10 +46,8 @@ class GeminiApiClient(ApiClient):
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
|
||||
@@ -68,3 +68,17 @@ class ModelService:
|
||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
def check_model_support(self, model: str) -> bool:
|
||||
if not model or not isinstance(model, str):
|
||||
return False
|
||||
|
||||
model = model.strip()
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
return model in settings.MODEL_SEARCH
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
return model in settings.MODEL_IMAGE
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user