mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 14:21:27 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb5cd92041 | ||
|
|
0be85e9536 | ||
|
|
632dee38b3 | ||
|
|
16c28bf1ba |
@@ -10,7 +10,7 @@ COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=fasle
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV MODEL_SEARCH='["gemini-2.0-flash-exp"]'
|
||||
|
||||
# Expose port
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
|
||||
from copy import deepcopy
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
@@ -36,18 +36,40 @@ async def list_models(_=Depends(security_service.verify_key),
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
models_json = model_service.get_gemini_models(api_key)
|
||||
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
|
||||
"displayName": "Gemini 2.0 Flash Search Experimental",
|
||||
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
|
||||
"outputTokenLimit": 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
|
||||
"topP": 0.95, "topK": 64, "maxTemperature": 2})
|
||||
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-image", "version": "2.0",
|
||||
"displayName": "Gemini 2.0 Flash Image Experimental",
|
||||
"description": "Gemini 2.0 Flash Image Experimental", "inputTokenLimit": 32767,
|
||||
"outputTokenLimit": 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
|
||||
"topP": 0.95, "topK": 64, "maxTemperature": 2})
|
||||
|
||||
# 模型名称以及对应的详细信息
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||
|
||||
# 添加搜索模型
|
||||
if settings.MODEL_SEARCH:
|
||||
for name in settings.MODEL_SEARCH:
|
||||
model = model_mapping.get(name, None)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-search"
|
||||
display_name = f'{item.get("displayName")} For Search'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加图像生成模型
|
||||
if settings.MODEL_IMAGE:
|
||||
for name in settings.MODEL_IMAGE:
|
||||
model = model_mapping.get(name, None)
|
||||
if not model:
|
||||
continue
|
||||
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{name}-image"
|
||||
display_name = f'{item.get("displayName")} For Image'
|
||||
item["displayName"] = display_name
|
||||
item["description"] = display_name
|
||||
|
||||
models_json["models"].append(item)
|
||||
|
||||
return models_json
|
||||
|
||||
|
||||
@@ -68,6 +90,9 @@ async def generate_content(
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
@@ -98,6 +123,9 @@ async def stream_generate_content(
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
try:
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
|
||||
@@ -61,6 +61,10 @@ async def chat_completion(
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
|
||||
@@ -24,12 +24,18 @@ class GeminiApiClient(ApiClient):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
def _get_real_model(self, model: str) -> str:
    """Return the model name with the local routing suffixes removed.

    A trailing ``-search`` is removed first, then a trailing ``-image``;
    a name carrying neither suffix is returned unchanged.

    Args:
        model: Model name as passed by the caller; may end in ``-search``
            or ``-image``.

    Returns:
        The model name without those suffixes.
    """
    # removesuffix ties the amount stripped to the suffix text itself,
    # unlike hand-counted slice lengths that can drift out of sync.
    return model.removesuffix("-search").removesuffix("-image")
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
@@ -40,10 +46,8 @@ class GeminiApiClient(ApiClient):
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
model = self._get_real_model(model)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
|
||||
@@ -34,7 +34,7 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction = None
|
||||
system_instruction_parts = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
@@ -64,8 +64,16 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction = {"role": "system", "parts": parts}
|
||||
system_instruction_parts.extend(parts)
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
system_instruction = (
|
||||
None
|
||||
if not system_instruction_parts
|
||||
else {
|
||||
"role": "system",
|
||||
"parts": system_instruction_parts,
|
||||
}
|
||||
)
|
||||
return converted_messages, system_instruction
|
||||
|
||||
@@ -68,3 +68,17 @@ class ModelService:
|
||||
image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
openai_format["data"].append(image_model)
|
||||
return openai_format
|
||||
|
||||
def check_model_support(self, model: str) -> bool:
    """Report whether a requested model name may be served.

    Names ending in ``-search`` or ``-image`` are accepted only when the
    base name (suffix removed) is listed in ``settings.MODEL_SEARCH`` or
    ``settings.MODEL_IMAGE`` respectively. Any other non-blank name is
    accepted without further checks.

    Args:
        model: Model name from the request; surrounding whitespace is
            ignored.

    Returns:
        False for non-string, empty or whitespace-only names and for
        suffixed names whose base is not configured; True otherwise.
    """
    if not isinstance(model, str):
        return False

    # Strip before the emptiness test so a whitespace-only name is rejected
    # just like an empty one instead of falling through to True.
    model = model.strip()
    if not model:
        return False

    # `or ()` keeps the membership test safe if a setting is unset (None).
    if model.endswith("-search"):
        return model.removesuffix("-search") in (settings.MODEL_SEARCH or ())
    if model.endswith("-image"):
        return model.removesuffix("-image") in (settings.MODEL_IMAGE or ())

    return True
|
||||
Reference in New Issue
Block a user