mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-09 01:29:46 +08:00
本次提交主要包含以下更改: - 代码清理: - 移除了 `app/router/` 目录下多个路由文件 ([`config_routes.py`](app/router/config_routes.py:1), [`error_log_routes.py`](app/router/error_log_routes.py:1), [`gemini_routes.py`](app/router/gemini_routes.py:1), [`openai_compatiable_routes.py`](app/router/openai_compatiable_routes.py:1), [`openai_routes.py`](app/router/openai_routes.py:1), [`routes.py`](app/router/routes.py:1), [`scheduler_routes.py`](app/router/scheduler_routes.py:1), [`stats_routes.py`](app/router/stats_routes.py:1), [`version_routes.py`](app/router/version_routes.py:1)) 中的大量解释性注释、TODO 注释和多余的日志标记。 - 清理了 [`scheduler_routes.py`](app/router/scheduler_routes.py:31) 中被注释掉的认证逻辑。 - 这些清理旨在提高代码的整洁度和可维护性。 - UI 优化: - 在 [`app/templates/config_editor.html`](app/templates/config_editor.html:327) 中,为 Gemini 模型的安全过滤级别设置增加了一条重要的提示信息,建议用户将其设置为 "OFF" 以避免影响输出速度,并强调非必要不应随意改动。
374 lines
16 KiB
Python
374 lines
16 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
from fastapi.responses import StreamingResponse, JSONResponse
|
||
from copy import deepcopy
|
||
import asyncio
|
||
from app.config.config import settings
|
||
from app.log.logger import get_gemini_logger
|
||
from app.core.security import SecurityService
|
||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||
from app.service.model.model_service import ModelService
|
||
from app.handler.retry_handler import RetryHandler
|
||
from app.handler.error_handler import handle_route_errors
|
||
from app.core.constants import API_VERSION
|
||
|
||
# Router exposing the Gemini endpoints under the "/gemini/<API_VERSION>" prefix.
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
# Second router exposing routes under the bare "/<API_VERSION>" prefix.
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
# Module-level logger shared by every handler in this file.
logger = get_gemini_logger()

# Service objects created once at import time and reused by the handlers.
security_service = SecurityService()
model_service = ModelService()
|
||
|
||
|
||
async def get_key_manager():
    """Dependency provider: return the key manager instance.

    Awaits ``get_key_manager_instance()`` and hands its result back unchanged.
    """
    manager = await get_key_manager_instance()
    return manager
|
||
|
||
|
||
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
    """Dependency provider: return the next usable API key.

    Delegates to ``key_manager.get_next_working_key()`` and returns its result.
    """
    next_key = await key_manager.get_next_working_key()
    return next_key
|
||
|
||
|
||
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Dependency provider: build a Gemini chat service.

    A new ``GeminiChatService`` is constructed on each call from
    ``settings.BASE_URL`` and the injected key manager.
    """
    service = GeminiChatService(settings.BASE_URL, key_manager)
    return service
|
||
|
||
|
||
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
    _=Depends(security_service.verify_key_or_goog_api_key),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Return the Gemini model list, extended with configured derived models.

    The base list is fetched through ``model_service`` using the first valid
    API key. For every base model named in ``settings.SEARCH_MODELS``,
    ``settings.IMAGE_MODELS`` and ``settings.THINKING_MODELS`` a copy is
    appended whose name carries the ``-search``, ``-image`` or
    ``-non-thinking`` suffix respectively.

    Returns:
        A deep copy of the fetched models payload with the derived entries
        appended to its ``"models"`` list.

    Raises:
        HTTPException: 503 when no valid API key is available; 500 when the
            base list is missing or any other error occurs.
    """
    operation_name = "list_gemini_models"
    logger.info("-" * 50 + operation_name + "-" * 50)
    logger.info("Handling Gemini models list request")

    try:
        api_key = await key_manager.get_first_valid_key()
        if not api_key:
            raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
        logger.info(f"Using API key: {api_key}")

        models_data = await model_service.get_gemini_models(api_key)
        if not models_data or "models" not in models_data:
            raise HTTPException(status_code=500, detail="Failed to fetch base models list.")

        # Work on a copy so the payload returned by the service is not mutated.
        models_json = deepcopy(models_data)

        # Index base models by short name, i.e. the part after the first "/".
        model_mapping = {}
        for entry in models_json.get("models", []):
            short_name = entry.get("name", "").split("/", maxsplit=1)[-1]
            model_mapping[short_name] = entry

        # (configured base names, name suffix, display-name suffix)
        derived_specs = (
            (settings.SEARCH_MODELS, "-search", " For Search"),
            (settings.IMAGE_MODELS, "-image", " For Image"),
            (settings.THINKING_MODELS, "-non-thinking", " Non Thinking"),
        )
        for base_names, suffix, display_suffix in derived_specs:
            if not base_names:
                continue
            for base_name in base_names:
                base_model = model_mapping.get(base_name)
                if not base_model:
                    logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
                    continue
                derived = deepcopy(base_model)
                derived["name"] = f"models/{base_name}{suffix}"
                display_name = f'{derived.get("displayName", base_name)}{display_suffix}'
                derived["displayName"] = display_name
                derived["description"] = display_name
                models_json["models"].append(derived)

        logger.info("Gemini models list request successful")
        return models_json
    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting Gemini models list: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error while fetching Gemini models list"
        ) from e
|
||
|
||
|
||
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(key_arg="api_key")
async def generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a non-streaming Gemini content generation request.

    The whole body runs inside ``handle_route_errors``. The model name is
    checked with ``model_service.check_model_support`` before the request is
    forwarded to ``chat_service.generate_content``.

    Raises:
        HTTPException: 400 when the model is not supported.
    """
    # NOTE(review): key_manager is not referenced in this body; presumably the
    # RetryHandler decorator reads it from the call arguments - confirm.
    operation_name = "gemini_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
        logger.info(f"Handling Gemini content generation request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")

        model_supported = await model_service.check_model_support(model_name)
        if not model_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        return await chat_service.generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
|
||
|
||
|
||
@router.post("/models/{model_name}:streamGenerateContent")
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(key_arg="api_key")
async def stream_generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a streaming Gemini content generation request.

    The whole body runs inside ``handle_route_errors``. After the model name
    passes ``model_service.check_model_support``, the stream produced by
    ``chat_service.stream_generate_content`` is wrapped in a
    ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 when the model is not supported.
    """
    # NOTE(review): key_manager is not referenced in this body; presumably the
    # RetryHandler decorator reads it from the call arguments - confirm.
    operation_name = "gemini_stream_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
        logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")

        model_supported = await model_service.check_model_support(model_name)
        if not model_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        # The service call is not awaited here; its result is handed to the response as-is.
        stream = chat_service.stream_generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
        return StreamingResponse(stream, media_type="text/event-stream")
|
||
|
||
|
||
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
    """Reset failure counts for Gemini API keys in bulk.

    ``key_type`` selects the scope: ``"valid"`` or ``"invalid"`` resets only
    the keys in that bucket of ``key_manager.get_keys_by_status()``; any other
    value (including ``None``) resets every key.

    Returns:
        JSONResponse: a success payload (with ``reset_count`` for the scoped
        variants), or a 500 payload describing the error.
    """
    logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
    logger.info(f"Received reset request with key_type: {key_type}")

    try:
        # Fetch the keys grouped by status.
        keys_by_status = await key_manager.get_keys_by_status()
        buckets = {
            "valid": keys_by_status.get("valid_keys", {}),
            "invalid": keys_by_status.get("invalid_keys", {}),
        }

        if key_type not in buckets:
            # No recognised scope was given: reset every key.
            await key_manager.reset_failure_counts()
            return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})

        selected_keys = list(buckets[key_type])
        logger.info(f"Resetting only {key_type} keys, count: {len(selected_keys)}")

        # Reset each key of the requested bucket in turn.
        for selected_key in selected_keys:
            await key_manager.reset_key_failure_count(selected_key)

        payload = {
            "success": True,
            "message": f"{key_type}密钥的失败计数已重置",
            "reset_count": len(selected_keys)
        }
        return JSONResponse(payload)
    except Exception as e:
        logger.error(f"Failed to reset key failure counts: {str(e)}")
        return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
|
||
|
||
|
||
@router.post("/reset-selected-fail-counts")
async def reset_selected_key_fail_counts(
    request: ResetSelectedKeysRequest,
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Reset failure counts for the Gemini API keys listed in the request.

    Each key is reset individually. A falsy result from
    ``key_manager.reset_key_failure_count`` is logged as "not found" and is
    not counted; an exception for one key is recorded and the loop continues.

    Returns:
        JSONResponse: 400 when no keys were supplied; 200 when no per-key
        error occurred; 207 when some keys were reset and some raised; 500
        when errors occurred and nothing was reset, or on an unexpected
        failure.
    """
    logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
    keys_to_reset = request.keys
    key_type = request.key_type
    logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")

    if not keys_to_reset:
        return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)

    reset_count = 0
    errors = []

    try:
        for key in keys_to_reset:
            try:
                was_reset = await key_manager.reset_key_failure_count(key)
            except Exception as key_error:
                logger.error(f"Error resetting key {key}: {str(key_error)}")
                errors.append(f"Key {key}: {str(key_error)}")
                continue
            if was_reset:
                reset_count += 1
            else:
                logger.warning(f"Key not found during selective reset: {key}")

        if not errors:
            return JSONResponse({
                "success": True,
                "message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
                "reset_count": reset_count
            })

        # At least one key raised: partial success maps to 207, total failure to 500.
        partially_ok = reset_count > 0
        return JSONResponse({
            "success": partially_ok,
            "message": f"批量重置完成,但出现错误: {'; '.join(errors)}",
            "reset_count": reset_count
        }, status_code=207 if partially_ok else 500)
    except Exception as e:
        logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
        return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||
|
||
|
||
@router.post("/reset-fail-count/{api_key}")
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
    """Reset the failure count of one Gemini API key.

    Returns:
        JSONResponse: success payload when
        ``key_manager.reset_key_failure_count`` reports a truthy result, a 404
        payload when it reports a falsy one, and a 500 payload when it raises.
    """
    logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
    logger.info(f"Resetting failure count for API key: {api_key}")

    try:
        if not await key_manager.reset_key_failure_count(api_key):
            return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
        return JSONResponse({"success": True, "message": "失败计数已重置"})
    except Exception as e:
        logger.error(f"Failed to reset key failure count: {str(e)}")
        return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
||
|
||
|
||
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
    """Check whether a single Gemini API key works.

    A minimal "hi" prompt is sent to ``settings.TEST_MODEL`` using the given
    key.

    Returns:
        JSONResponse: ``{"status": "valid"}`` when a non-empty response comes
        back; ``{"status": "invalid", "error": ...}`` when the call raises or
        the response is empty. When the call raises and the key is already
        tracked in ``key_manager.key_failure_counts``, its count is
        incremented.
    """
    logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
    logger.info("Verifying API key validity")

    try:
        gemini_request = GeminiRequest(
            contents=[
                GeminiContent(
                    role="user",
                    parts=[{"text": "hi"}],
                )
            ],
            generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
        )

        response = await chat_service.generate_content(
            settings.TEST_MODEL,
            gemini_request,
            api_key
        )
    except Exception as e:
        logger.error(f"Key verification failed: {str(e)}")

        async with key_manager.failure_count_lock:
            if api_key in key_manager.key_failure_counts:
                key_manager.key_failure_counts[api_key] += 1
                logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")

        return JSONResponse({"status": "invalid", "error": str(e)})

    if response:
        return JSONResponse({"status": "valid"})

    # Previously an empty response fell through and the handler returned None,
    # which reached the client as a bare `null` body with no status field.
    logger.warning("Key verification returned an empty response")
    return JSONResponse({"status": "invalid", "error": "Empty response from model"})
|
||
|
||
|
||
@router.post("/verify-selected-keys")
async def verify_selected_keys(
    request: VerifySelectedKeysRequest,
    chat_service: GeminiChatService = Depends(get_chat_service),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Verify the Gemini API keys listed in the request concurrently.

    Each key is tested by sending a minimal "hi" prompt to
    ``settings.TEST_MODEL``. A key whose request raises is recorded as failed
    and its entry in ``key_manager.key_failure_counts`` is incremented, or
    created with the value 1 if absent.

    Returns:
        JSONResponse: 400 when no keys were supplied; otherwise a payload
        listing the successful keys, the failed keys with their error
        messages, and the counts of each.
    """
    logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
    keys_to_verify = request.keys
    logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")

    if not keys_to_verify:
        return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)

    successful_keys = []
    failed_keys = {}

    async def _verify_single_key(api_key: str):
        """Verify one key and record the outcome in the enclosing collections."""
        try:
            gemini_request = GeminiRequest(
                contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
                generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
            )
            await chat_service.generate_content(
                settings.TEST_MODEL,
                gemini_request,
                api_key
            )
        except Exception as e:
            error_message = str(e)
            logger.warning(f"Key verification failed for {api_key}: {error_message}")
            async with key_manager.failure_count_lock:
                if api_key in key_manager.key_failure_counts:
                    key_manager.key_failure_counts[api_key] += 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
                else:
                    key_manager.key_failure_counts[api_key] = 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
            failed_keys[api_key] = error_message
            return api_key, "invalid", error_message

        successful_keys.append(api_key)
        return api_key, "valid", None

    results = await asyncio.gather(
        *(_verify_single_key(key) for key in keys_to_verify),
        return_exceptions=True
    )

    # Outcomes are already recorded by the tasks themselves; only surface
    # anything that escaped them. BaseException is checked because gather can
    # hand back non-Exception errors such as CancelledError, which the old
    # code tried to tuple-unpack and crashed on.
    for result in results:
        if isinstance(result, BaseException):
            logger.error(f"An unexpected error occurred during bulk verification task: {result}")

    valid_count = len(successful_keys)
    invalid_count = len(failed_keys)
    logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")

    if failed_keys:
        message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
    else:
        message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"

    # With no failures, failed_keys is {} and invalid_count is 0, so one
    # payload covers both outcomes.
    return JSONResponse({
        "success": True,
        "message": message,
        "successful_keys": successful_keys,
        "failed_keys": failed_keys,
        "valid_count": valid_count,
        "invalid_count": invalid_count
    })