mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-07 07:22:41 +08:00
Refactor: 大幅清理代码注释并优化配置提示
本次提交主要包含以下更改: - 代码清理: - 移除了 `app/router/` 目录下多个路由文件 ([`config_routes.py`](app/router/config_routes.py:1), [`error_log_routes.py`](app/router/error_log_routes.py:1), [`gemini_routes.py`](app/router/gemini_routes.py:1), [`openai_compatiable_routes.py`](app/router/openai_compatiable_routes.py:1), [`openai_routes.py`](app/router/openai_routes.py:1), [`routes.py`](app/router/routes.py:1), [`scheduler_routes.py`](app/router/scheduler_routes.py:1), [`stats_routes.py`](app/router/stats_routes.py:1), [`version_routes.py`](app/router/version_routes.py:1)) 中的大量解释性注释、TODO 注释和多余的日志标记。 - 清理了 [`scheduler_routes.py`](app/router/scheduler_routes.py:31) 中被注释掉的认证逻辑。 - 这些清理旨在提高代码的整洁度和可维护性。 - UI 优化: - 在 [`app/templates/config_editor.html`](app/templates/config_editor.html:327) 中,为 Gemini 模型的安全过滤级别设置增加了一条重要的提示信息,建议用户将其设置为 "OFF" 以避免影响输出速度,并强调非必要不应随意改动。
This commit is contained in:
@@ -6,15 +6,13 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_config_routes_logger, Logger # 导入 Logger 类
|
||||
from app.log.logger import get_config_routes_logger, Logger
|
||||
from app.service.config.config_service import ConfigService
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
|
||||
|
||||
@router.get("", response_model=Dict[str, Any])
|
||||
async def get_config(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
@@ -34,10 +32,10 @@ async def update_config(config_data: Dict[str, Any], request: Request):
|
||||
result = await ConfigService.update_config(config_data)
|
||||
# 配置更新成功后,立即更新所有 logger 的级别
|
||||
Logger.update_log_levels(config_data["LOG_LEVEL"])
|
||||
logger.info("Log levels updated after configuration change.") # 添加日志记录
|
||||
logger.info("Log levels updated after configuration change.")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating config or log levels: {e}", exc_info=True) # 记录详细错误
|
||||
logger.error(f"Error updating config or log levels: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -8,34 +8,29 @@ from fastapi import APIRouter, HTTPException, Request, Query, Path, Body, Respon
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_log_routes_logger
|
||||
# 假设这些服务函数已更新或添加
|
||||
from app.database.services import (
|
||||
get_error_logs,
|
||||
get_error_logs_count,
|
||||
get_error_log_details,
|
||||
delete_error_logs_by_ids, # 新增导入
|
||||
delete_error_log_by_id # 新增导入
|
||||
delete_error_logs_by_ids,
|
||||
delete_error_log_by_id
|
||||
)
|
||||
# Removed get_db import comment as it's fully removed now
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
||||
|
||||
logger = get_log_routes_logger()
|
||||
|
||||
|
||||
# Define a response model that includes the total count for pagination
|
||||
# 用于列表响应的模型,假设 get_error_logs 返回包含 error_code 的字典
|
||||
class ErrorLogListItem(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_code: Optional[int] = None # 列表显示错误码 (应为整数)
|
||||
error_code: Optional[int] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
class ErrorLogListResponse(BaseModel):
|
||||
logs: List[ErrorLogListItem] # 使用定义的模型列表
|
||||
logs: List[ErrorLogListItem]
|
||||
total: int
|
||||
|
||||
@router.get("/errors", response_model=ErrorLogListResponse)
|
||||
@@ -44,12 +39,12 @@ async def get_error_logs_api(
|
||||
limit: int = Query(10, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
key_search: Optional[str] = Query(None, description="Search term for Gemini key (partial match)"),
|
||||
error_search: Optional[str] = Query(None, description="Search term for error type or log message"), # 数据库查询需处理
|
||||
error_code_search: Optional[str] = Query(None, description="Search term for error code"), # Added error code search parameter
|
||||
error_search: Optional[str] = Query(None, description="Search term for error type or log message"),
|
||||
error_code_search: Optional[str] = Query(None, description="Search term for error code"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start datetime for filtering"),
|
||||
end_date: Optional[datetime] = Query(None, description="End datetime for filtering"),
|
||||
sort_by: str = Query('id', description="Field to sort by (e.g., 'id', 'request_time')"), # 新增排序参数
|
||||
sort_order: str = Query('desc', description="Sort order ('asc' or 'desc')") # 新增排序参数
|
||||
sort_by: str = Query('id', description="Field to sort by (e.g., 'id', 'request_time')"),
|
||||
sort_order: str = Query('desc', description="Sort order ('asc' or 'desc')")
|
||||
):
|
||||
"""
|
||||
获取错误日志列表 (返回错误码),支持过滤和排序
|
||||
@@ -72,32 +67,27 @@ async def get_error_logs_api(
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to error logs list")
|
||||
# API 返回 401 更合适
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
# 假设 get_error_logs 现在返回包含 error_code 的字典列表
|
||||
# 并且可以接受 include_error_code 参数 (如果需要显式指定)
|
||||
logs_data = await get_error_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
key_search=key_search,
|
||||
error_search=error_search, # 数据库查询需要处理这个
|
||||
error_code_search=error_code_search, # Pass error code search to DB function
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sort_by=sort_by, # 传递排序参数
|
||||
sort_order=sort_order # 传递排序参数
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
# Fetch total count with the same search parameters
|
||||
total_count = await get_error_logs_count(
|
||||
key_search=key_search,
|
||||
error_search=error_search,
|
||||
error_code_search=error_code_search, # Pass error code search to DB count function
|
||||
error_code_search=error_code_search,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 验证并转换数据以匹配 Pydantic 模型
|
||||
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
||||
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
||||
except Exception as e:
|
||||
@@ -105,13 +95,12 @@ async def get_error_logs_api(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get error logs list: {str(e)}")
|
||||
|
||||
|
||||
# 新增:获取错误日志详情的路由
|
||||
class ErrorLogDetailResponse(BaseModel):
|
||||
id: int
|
||||
gemini_key: Optional[str] = None
|
||||
error_type: Optional[str] = None
|
||||
error_log: Optional[str] = None # 详情接口返回完整的 error_log
|
||||
request_msg: Optional[str] = None # 详情接口返回 request_msg
|
||||
error_log: Optional[str] = None
|
||||
request_msg: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
|
||||
@@ -126,27 +115,22 @@ async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
# 假设存在一个函数 get_error_log_details(log_id) 来获取完整信息
|
||||
log_details = await get_error_log_details(log_id=log_id)
|
||||
if not log_details:
|
||||
raise HTTPException(status_code=404, detail="Error log not found")
|
||||
|
||||
# 假设 get_error_log_details 返回一个字典或兼容 Pydantic 的对象
|
||||
return ErrorLogDetailResponse(**log_details)
|
||||
except HTTPException as http_exc:
|
||||
# Re-raise HTTPException (like 404)
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get error log details: {str(e)}")
|
||||
|
||||
|
||||
# 新增:批量删除错误日志
|
||||
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_logs_bulk_api(
|
||||
request: Request,
|
||||
payload: Dict[str, List[int]] = Body(...) # Expects {"ids": [1, 2, 3]}
|
||||
# Ensure db dependency is fully removed
|
||||
payload: Dict[str, List[int]] = Body(...)
|
||||
):
|
||||
"""
|
||||
批量删除错误日志 (异步)
|
||||
@@ -161,7 +145,6 @@ async def delete_error_logs_bulk_api(
|
||||
raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")
|
||||
|
||||
try:
|
||||
# 调用异步服务函数
|
||||
deleted_count = await delete_error_logs_by_ids(log_ids)
|
||||
# 注意:异步函数返回的是尝试删除的数量,可能不是精确值
|
||||
logger.info(f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}")
|
||||
@@ -171,12 +154,10 @@ async def delete_error_logs_bulk_api(
|
||||
raise HTTPException(status_code=500, detail="Internal server error during bulk deletion")
|
||||
|
||||
|
||||
# 新增:删除单个错误日志
|
||||
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_log_api(
|
||||
request: Request,
|
||||
log_id: int = Path(..., ge=1)
|
||||
# Ensure db dependency is fully removed
|
||||
):
|
||||
"""
|
||||
删除单个错误日志 (异步)
|
||||
@@ -187,7 +168,6 @@ async def delete_error_log_api(
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
# 调用异步服务函数
|
||||
success = await delete_error_log_by_id(log_id)
|
||||
if not success:
|
||||
# 服务层现在在未找到时返回 False,我们在这里转换为 404
|
||||
@@ -195,7 +175,7 @@ async def delete_error_log_api(
|
||||
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc # Re-raise 404 or other HTTP exceptions
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error during deletion")
|
||||
|
||||
@@ -5,7 +5,7 @@ import asyncio
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest # 添加导入
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
@@ -13,12 +13,10 @@ from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
|
||||
# 路由设置
|
||||
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||
logger = get_gemini_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
|
||||
@@ -52,14 +50,14 @@ async def list_models(
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
models_data =await model_service.get_gemini_models(api_key)
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
|
||||
models_json = deepcopy(models_data) # 操作副本以防修改原始缓存
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
@@ -74,7 +72,6 @@ async def list_models(
|
||||
item["description"] = display_name
|
||||
models_json["models"].append(item)
|
||||
|
||||
# 添加衍生模型
|
||||
if settings.SEARCH_MODELS:
|
||||
for name in settings.SEARCH_MODELS:
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
@@ -88,7 +85,6 @@ async def list_models(
|
||||
logger.info("Gemini models list request successful")
|
||||
return models_json
|
||||
except HTTPException as http_exc:
|
||||
# 重新抛出已知的 HTTP 异常
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
@@ -139,7 +135,6 @@ async def stream_generate_content(
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
# 流式请求的成功/失败日志在流处理中更复杂,这里仅用上下文管理器处理启动错误
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
@@ -153,7 +148,6 @@ async def stream_generate_content(
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
# 注意:流本身的错误需要在服务层或流迭代中处理,这里只返回流响应
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
|
||||
@@ -204,7 +198,7 @@ async def reset_selected_key_fail_counts(
|
||||
"""批量重置选定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
||||
keys_to_reset = request.keys
|
||||
key_type = request.key_type # 获取类型用于日志记录和响应消息
|
||||
key_type = request.key_type
|
||||
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
||||
|
||||
if not keys_to_reset:
|
||||
@@ -220,33 +214,27 @@ async def reset_selected_key_fail_counts(
|
||||
if result:
|
||||
reset_count += 1
|
||||
else:
|
||||
# 记录未找到的密钥,但不视为致命错误
|
||||
logger.warning(f"Key not found during selective reset: {key}")
|
||||
except Exception as key_error:
|
||||
# 记录单个密钥重置时的错误
|
||||
logger.error(f"Error resetting key {key}: {str(key_error)}")
|
||||
errors.append(f"Key {key}: {str(key_error)}")
|
||||
|
||||
if errors:
|
||||
# 如果有错误,报告部分成功或完全失败
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
# 确定最终状态码和成功标志
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500 # 207 Multi-Status if partially successful, 500 if completely failed
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse({
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count
|
||||
}, status_code=status_code)
|
||||
|
||||
# 完全成功的情况
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count
|
||||
})
|
||||
except Exception as e:
|
||||
# 捕获循环外的意外错误
|
||||
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||||
|
||||
@@ -274,7 +262,6 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
try:
|
||||
# 使用generate_content接口测试key的有效性
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
@@ -296,7 +283,6 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
# 验证出现异常时增加失败计数
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
@@ -320,80 +306,63 @@ async def verify_selected_keys(
|
||||
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
|
||||
|
||||
successful_keys = []
|
||||
failed_keys = {} # 存储失败的 key 和错误信息
|
||||
failed_keys = {}
|
||||
|
||||
async def _verify_single_key(api_key: str):
|
||||
"""内部函数,用于验证单个密钥并处理异常"""
|
||||
nonlocal successful_keys, failed_keys # 允许修改外部列表和字典
|
||||
nonlocal successful_keys, failed_keys
|
||||
try:
|
||||
# 重用单密钥验证逻辑的核心部分
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
)
|
||||
# 注意:这里直接调用 chat_service.generate_content,不依赖于 key_manager 获取密钥
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
)
|
||||
# 如果上面没有抛出异常,则认为密钥有效
|
||||
successful_keys.append(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.warning(f"Key verification failed for {api_key}: {error_message}")
|
||||
# 验证失败时增加失败计数 (使用与 /verify-key 一致的逻辑)
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
|
||||
else:
|
||||
# 如果密钥不在计数中(可能刚添加或从未失败),初始化为1
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
|
||||
failed_keys[api_key] = error_message # 记录失败的 key 和错误信息
|
||||
failed_keys[api_key] = error_message
|
||||
return api_key, "invalid", error_message
|
||||
|
||||
# 并发执行所有密钥的验证
|
||||
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True) # return_exceptions=True 捕获任务本身的异常
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理并发执行的结果
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
# 捕获 asyncio.gather 可能遇到的异常(例如任务被取消)
|
||||
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
|
||||
# 可以选择如何处理这种任务级别的错误,这里我们简单记录
|
||||
# 也可以将其计入 invalid_count 或单独记录
|
||||
elif result:
|
||||
# result 可能是 (key, status, error) 或 Exception
|
||||
if not isinstance(result, Exception) and result:
|
||||
key, status, error = result
|
||||
# 失败信息已在 _verify_single_key 中记录到 failed_keys
|
||||
elif isinstance(result, Exception):
|
||||
# 记录任务本身的异常,可以关联到一个特定的 key 如果可能的话
|
||||
# 这里简化处理,只记录日志
|
||||
logger.error(f"Task execution error during bulk verification: {result}")
|
||||
|
||||
valid_count = len(successful_keys)
|
||||
invalid_count = len(failed_keys)
|
||||
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
|
||||
|
||||
# 根据是否有失败的 key 决定最终消息和状态
|
||||
if failed_keys:
|
||||
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
||||
# 即使有失败也认为是部分成功,返回 200 OK,让前端处理详细结果
|
||||
return JSONResponse({
|
||||
"success": True, # 表示请求处理完成,具体结果看内容
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count, # 保留计数方便前端快速展示
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count
|
||||
})
|
||||
else:
|
||||
# 完全成功
|
||||
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
|
||||
@@ -18,7 +18,6 @@ from app.service.openai_compatiable.openai_compatiable_service import OpenAIComp
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
|
||||
async def get_key_manager():
|
||||
@@ -62,29 +61,23 @@ async def chat_completion(
|
||||
):
|
||||
"""处理聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
# 检查是否为图像生成相关的聊天模型,如果是,则使用付费密钥
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key # 保存原始key(可能是普通key)
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key() # 获取付费密钥
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}") # 使用 current_api_key
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
if is_image_chat:
|
||||
# 图像生成聊天,调用特定服务,不处理流式
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response # 直接返回结果
|
||||
return response
|
||||
else:
|
||||
# 普通聊天补全
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
# 假设 openai_service.create_chat_completion 在流式时返回异步生成器
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
|
||||
|
||||
@@ -98,7 +91,6 @@ async def generate_image(
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
# 强制使用配置的模型,确保请求中包含正确的模型信息
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.domain.openai_models import (
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors # 导入共享错误处理器
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
@@ -20,7 +20,6 @@ from app.service.model.model_service import ModelService
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
embedding_service = EmbeddingService()
|
||||
@@ -64,44 +63,35 @@ async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager), # 保留 key_manager 用于获取 paid_key
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
):
|
||||
"""处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。"""
|
||||
operation_name = "chat_completion"
|
||||
# 检查是否为图像生成相关的聊天模型
|
||||
is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
|
||||
current_api_key = api_key # 保存原始 key
|
||||
current_api_key = api_key
|
||||
if is_image_chat:
|
||||
current_api_key = await key_manager.get_paid_key() # 获取付费密钥
|
||||
current_api_key = await key_manager.get_paid_key()
|
||||
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
|
||||
# 检查模型支持性应在错误处理块内,以便捕获并记录错误
|
||||
if not await model_service.check_model_support(request.model):
|
||||
# 使用 HTTPException,会被 handle_route_errors 捕获并记录
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
if is_image_chat:
|
||||
# 图像生成聊天
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
else:
|
||||
# 普通聊天补全
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
# 处理流式响应
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
# 非流式直接返回结果
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from app.service.stats.stats_service import StatsService
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
# 配置Jinja2模板
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
@@ -25,7 +24,6 @@ def setup_routers(app: FastAPI) -> None:
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 包含API路由
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
@@ -36,12 +34,10 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
|
||||
# 添加页面路由
|
||||
setup_page_routes(app)
|
||||
|
||||
# 添加健康检查路由
|
||||
setup_health_routes(app)
|
||||
setup_api_stats_routes(app) # Add API stats routes
|
||||
setup_api_stats_routes(app)
|
||||
|
||||
|
||||
def setup_page_routes(app: FastAPI) -> None:
|
||||
@@ -106,16 +102,14 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total_keys": total_keys, # Renamed for clarity
|
||||
"valid_key_count": valid_key_count, # Added count
|
||||
"invalid_key_count": invalid_key_count, # Added count
|
||||
"api_stats": api_stats, # <-- Pass stats to template
|
||||
"total_keys": total_keys,
|
||||
"valid_key_count": valid_key_count,
|
||||
"invalid_key_count": invalid_key_count,
|
||||
"api_stats": api_stats,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status or API stats: {str(e)}")
|
||||
# Optionally, render template with error or default stats
|
||||
# For now, re-raise to show error page
|
||||
raise
|
||||
|
||||
@app.get("/config", response_class=HTMLResponse)
|
||||
@@ -175,16 +169,13 @@ def setup_api_stats_routes(app: FastAPI) -> None:
|
||||
async def api_stats_details(request: Request, period: str):
|
||||
"""获取指定时间段内的 API 调用详情"""
|
||||
try:
|
||||
# 验证认证
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to API stats details")
|
||||
# Returning JSON error instead of redirect for API endpoint
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
logger.info(f"Fetching API call details for period: {period}")
|
||||
# Use the service instance here as well
|
||||
stats_service = StatsService() # Create an instance
|
||||
stats_service = StatsService()
|
||||
details = await stats_service.get_api_call_details(period)
|
||||
return details
|
||||
except ValueError as e:
|
||||
|
||||
@@ -2,22 +2,20 @@
|
||||
定时任务控制路由模块
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, status # 移除 Depends, 添加 Request
|
||||
from fastapi import APIRouter, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.security import verify_auth_token # 导入 verify_auth_token
|
||||
from app.core.security import verify_auth_token
|
||||
from app.scheduler.key_checker import start_scheduler, stop_scheduler
|
||||
from app.log.logger import get_scheduler_routes # 使用路由日志记录器
|
||||
from app.log.logger import get_scheduler_routes
|
||||
|
||||
logger = get_scheduler_routes()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/scheduler",
|
||||
tags=["Scheduler"]
|
||||
# 移除全局依赖
|
||||
)
|
||||
|
||||
# 认证检查的辅助函数
|
||||
async def verify_token(request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
@@ -29,14 +27,12 @@ async def verify_token(request: Request):
|
||||
)
|
||||
|
||||
@router.post("/start", summary="启动定时任务")
|
||||
async def start_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
async def start_scheduler_endpoint(request: Request):
|
||||
"""Start the background scheduler task"""
|
||||
"""
|
||||
await verify_token(request) # 在函数开始处进行认证检查
|
||||
"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to start scheduler.")
|
||||
start_scheduler() # 调用 key_checker 中的函数
|
||||
start_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
|
||||
@@ -46,14 +42,12 @@ async def start_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
)
|
||||
|
||||
@router.post("/stop", summary="停止定时任务")
|
||||
async def stop_scheduler_endpoint(request: Request): # 添加 request 参数
|
||||
async def stop_scheduler_endpoint(request: Request):
|
||||
"""Stop the background scheduler task"""
|
||||
"""
|
||||
await verify_token(request) # 在函数开始处进行认证检查
|
||||
"""
|
||||
await verify_token(request)
|
||||
try:
|
||||
logger.info("Received request to stop scheduler.")
|
||||
stop_scheduler() # 调用 key_checker 中的函数
|
||||
stop_scheduler()
|
||||
return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -45,9 +45,6 @@ async def get_key_usage_details(key: str):
|
||||
try:
|
||||
usage_details = await stats_service.get_key_usage_details_last_24h(key)
|
||||
if usage_details is None:
|
||||
# Handle case where key might be valid but has no recent usage,
|
||||
# or if the service layer explicitly returns None for other reasons.
|
||||
# Returning an empty dict is usually fine for the frontend.
|
||||
return {}
|
||||
return usage_details
|
||||
except Exception as e:
|
||||
|
||||
@@ -21,10 +21,9 @@ async def get_version_info():
|
||||
检查当前应用程序版本与最新的 GitHub release 版本。
|
||||
"""
|
||||
try:
|
||||
current_version = get_current_version() # Use imported function
|
||||
current_version = get_current_version()
|
||||
update_available, latest_version, error_message = await check_for_updates()
|
||||
|
||||
# Log the result for debugging
|
||||
logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")
|
||||
|
||||
return VersionInfo(
|
||||
|
||||
@@ -324,6 +324,7 @@
|
||||
</button>
|
||||
</div>
|
||||
<small class="text-gray-500 mt-1 block">配置模型的安全过滤级别,例如 HARM_CATEGORY_HARASSMENT: BLOCK_NONE。</small>
|
||||
<small class="text-red-500 mt-1 block">建议设置成OFF,其他值会影响输出速度,非必要不要随便改动。</small>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user