# Mirror of https://github.com/snailyp/gemini-balance.git
# Synced 2026-05-16 09:37:36 +08:00 · 142 lines · 5.8 KiB · Python
from fastapi import APIRouter, Depends, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.core.logger import get_openai_logger
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager, get_key_manager_instance
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
|
|
|
|
# Logger and router shared by every endpoint defined in this module.
logger = get_openai_logger()
router: APIRouter = APIRouter()

# Service singletons, constructed once when this module is imported,
# from values held in the application settings.
security_service: SecurityService = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service: ModelService = ModelService(settings.MODEL_SEARCH)
embedding_service: EmbeddingService = EmbeddingService(settings.BASE_URL)
image_create_service: ImageCreateService = ImageCreateService()
|
|
|
|
async def get_key_manager():
    """FastAPI dependency returning the KeyManager from get_key_manager_instance()."""
    manager = await get_key_manager_instance()
    return manager
|
|
|
|
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency that resolves to the key manager's next working API key.

    Lets an endpoint receive a pre-selected key as an injected parameter
    (e.g. ``api_key: str = Depends(get_next_working_key_wrapper)``).
    """
    next_key = await key_manager.get_next_working_key()
    return next_key
|
|
|
|
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Return the model list from ``model_service.get_gemini_openai_models``.

    The list is fetched with the key manager's next working API key.

    Raises:
        HTTPException: 500 if fetching the model list fails.
    """
    logger.info("-" * 50 + "list_models" + "-" * 50)
    logger.info("Handling models list request")
    api_key = await key_manager.get_next_working_key()
    # Never write the full secret to the logs; keep just enough to tell keys apart.
    masked_key = (
        f"{api_key[:4]}...{api_key[-4:]}"
        if isinstance(api_key, str) and len(api_key) > 8
        else "***"
    )
    logger.info(f"Using API key: {masked_key}")
    try:
        return model_service.get_gemini_openai_models(api_key)
    except Exception as e:
        logger.error(f"Error getting models list: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
|
|
|
|
|
|
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Handle a chat completion request.

    Routing:
        - ``"{settings.CREATE_IMAGE_MODEL}-chat"``: ``api_key`` is switched to
          ``key_manager.get_paid_key()``. The request then goes to
          ``create_image_chat_completion``.
        - Any other model: ``create_chat_completion`` runs with the injected
          ``api_key``. RetryHandler is configured with ``key_arg="api_key"``, so it
          presumably swaps that key between attempts; confirm in RetryHandler.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) wrapping the service result
        when ``request.stream`` is true; otherwise the service result itself.

    Raises:
        HTTPException: 500 when the chat service call fails.
    """
    # Evaluate the image-chat check once: both the key choice and the service call depend on it.
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    if is_image_chat:
        # The image-chat model uses the paid key instead of the rotating pool key.
        api_key = await key_manager.get_paid_key()
    chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
    logger.info("-" * 50 + "chat_completion" + "-" * 50)
    logger.info(f"Handling chat completion request for model: {request.model}")
    logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
    # Never write the full secret to the logs; keep just enough to tell keys apart.
    masked_key = (
        f"{api_key[:4]}...{api_key[-4:]}"
        if isinstance(api_key, str) and len(api_key) > 8
        else "***"
    )
    logger.info(f"Using API key: {masked_key}")
    try:
        if is_image_chat:
            # NOTE(review): the paid key selected above is not passed to this call;
            # presumably the image path resolves its own key. Confirm against OpenAIChatService.
            response = await chat_service.create_image_chat_completion(request=request)
        else:
            response = await chat_service.create_chat_completion(request, api_key)
        # Streaming requests: hand the service result to an SSE response.
        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        logger.info("Chat completion request successful")
        return response
    except Exception as e:
        # NOTE(review): this body runs inside the RetryHandler wrapper, so this line is
        # emitted per attempt, not only after the last retry. Confirm the intended wording.
        logger.error(f"Chat completion failed after retries: {str(e)}")
        raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
|
|
|
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """Handle an image generation request.

    Returns:
        Whatever ``image_create_service.generate_images(request)`` returns.

    Raises:
        HTTPException: 500 when image generation fails.
    """
    logger.info("-" * 50 + "generate_image" + "-" * 50)
    logger.info(f"Handling image generation request for prompt: {request.prompt}")

    try:
        # generate_images() is a plain (non-awaited) call. Running it in the thread pool
        # keeps a slow synchronous generation from stalling the event loop, and with it
        # every other in-flight request.
        response = await run_in_threadpool(image_create_service.generate_images, request)
        logger.info("Image generation request successful")
        return response
    except Exception as e:
        logger.error(f"Image generation request failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Image generation request failed") from e
|
|
|
|
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Create embeddings for ``request.input`` with ``request.model``.

    Uses the key manager's next working API key.

    Returns:
        The result of ``embedding_service.create_embedding``.

    Raises:
        HTTPException: 500 when the embedding call fails.
    """
    logger.info("-" * 50 + "embedding" + "-" * 50)
    logger.info(f"Handling embedding request for model: {request.model}")
    api_key = await key_manager.get_next_working_key()
    # Never write the full secret to the logs; keep just enough to tell keys apart.
    masked_key = (
        f"{api_key[:4]}...{api_key[-4:]}"
        if isinstance(api_key, str) and len(api_key) > 8
        else "***"
    )
    logger.info(f"Using API key: {masked_key}")
    try:
        response = await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
        logger.info("Embedding request successful")
        return response
    except Exception as e:
        logger.error(f"Embedding request failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
|
|
|
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
    _=Depends(security_service.verify_auth_token),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Return the valid and invalid API keys reported by the key manager.

    This endpoint is guarded by ``verify_auth_token``, not the
    ``verify_authorization`` check used by the other endpoints.

    Returns:
        ``{"status": "success", "data": {"valid_keys": [...], "invalid_keys": [...]},
        "total": <len(valid) + len(invalid)>}``

    Raises:
        HTTPException: 500 if the key lookup or payload construction fails.
    """
    logger.info("-" * 50 + "get_keys_list" + "-" * 50)
    logger.info("Handling keys list request")
    try:
        status_map = await key_manager.get_keys_by_status()
        valid = status_map["valid_keys"]
        invalid = status_map["invalid_keys"]
        payload = {
            "status": "success",
            "data": {"valid_keys": valid, "invalid_keys": invalid},
            "total": len(valid) + len(invalid),
        }
        return payload
    except Exception as e:
        logger.error(f"Error getting keys list: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="Internal server error while fetching keys list"
        ) from e
|