diff --git a/.env.example b/.env.example index cad0338..de40df7 100644 --- a/.env.example +++ b/.env.example @@ -79,4 +79,12 @@ URL_NORMALIZATION_ENABLED=false # tts配置 TTS_MODEL=gemini-2.5-flash-preview-tts TTS_VOICE_NAME=Zephyr -TTS_SPEED=normal \ No newline at end of file +TTS_SPEED=normal +#########################Files API 相关配置######################## +# 是否启用文件过期自动清理 +FILES_CLEANUP_ENABLED=true +# 文件过期清理间隔(小时) +FILES_CLEANUP_INTERVAL_HOURS=1 +# 是否启用用户文件隔离(每个用户只能看到自己上传的文件) +FILES_USER_ISOLATION_ENABLED=true +########################################################################## \ No newline at end of file diff --git a/app/config/config.py b/app/config/config.py index 5d437f1..21c4964 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -123,6 +123,10 @@ class Settings(BaseSettings): AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30 SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS + # Files API + FILES_CLEANUP_ENABLED: bool = True + FILES_CLEANUP_INTERVAL_HOURS: int = 1 + FILES_USER_ISOLATION_ENABLED: bool = True def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/app/database/initialization.py b/app/database/initialization.py index a9e8863..d6003f3 100644 --- a/app/database/initialization.py +++ b/app/database/initialization.py @@ -7,7 +7,7 @@ from sqlalchemy import inspect from sqlalchemy.orm import Session from app.database.connection import engine, Base -from app.database.models import Settings +from app.database.models import Settings, FileRecord from app.log.logger import get_database_logger logger = get_database_logger() diff --git a/app/database/models.py b/app/database/models.py index c33ae6a..478b4d7 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -2,7 +2,8 @@ 数据库模型模块 """ import datetime -from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean +from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum +import enum from 
app.database.connection import Base @@ -60,3 +61,69 @@ class RequestLog(Base): def __repr__(self): return f"" + + +class FileState(enum.Enum): + """文件状态枚举""" + PROCESSING = "PROCESSING" + ACTIVE = "ACTIVE" + FAILED = "FAILED" + + +class FileRecord(Base): + """ + 文件记录表,用于存储上传到 Gemini 的文件信息 + """ + __tablename__ = "t_file_records" + + id = Column(Integer, primary_key=True, autoincrement=True) + + # 文件基本信息 + name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}") + display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名") + mime_type = Column(String(100), nullable=False, comment="MIME 类型") + size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)") + sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值") + + # 状态信息 + state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态") + + # 时间戳 + create_time = Column(DateTime, nullable=False, comment="创建时间") + update_time = Column(DateTime, nullable=False, comment="更新时间") + expiration_time = Column(DateTime, nullable=False, comment="过期时间") + + # API 相关 + uri = Column(String(500), nullable=False, comment="文件访问 URI") + api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key") + upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)") + + # 额外信息 + user_token = Column(String(100), nullable=True, comment="上传用户的 token") + upload_completed = Column(DateTime, nullable=True, comment="上传完成时间") + + def __repr__(self): + return f"" + + def to_dict(self): + """转换为字典格式,用于 API 响应""" + return { + "name": self.name, + "displayName": self.display_name, + "mimeType": self.mime_type, + "sizeBytes": str(self.size_bytes), + "createTime": self.create_time.isoformat() + "Z", + "updateTime": self.update_time.isoformat() + "Z", + "expirationTime": self.expiration_time.isoformat() + "Z", + "sha256Hash": self.sha256_hash, + "uri": self.uri, + "state": self.state.value if self.state else "PROCESSING" + } + + def 
is_expired(self): + """检查文件是否已过期""" + # 确保比较时都是 timezone-aware + expiration_time = self.expiration_time + if expiration_time.tzinfo is None: + expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc) + return datetime.datetime.now(datetime.timezone.utc) > expiration_time diff --git a/app/database/services.py b/app/database/services.py index 893b608..194f814 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -2,11 +2,11 @@ 数据库服务模块 """ from typing import List, Optional, Dict, Any, Union -from datetime import datetime +from datetime import datetime, timedelta, timezone from sqlalchemy import func, desc, asc, select, insert, update, delete import json from app.database.connection import database -from app.database.models import Settings, ErrorLog, RequestLog +from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState from app.log.logger import get_database_logger logger = get_database_logger() @@ -427,3 +427,264 @@ async def add_request_log( except Exception as e: logger.error(f"Failed to add request log: {str(e)}") return False + + +# ==================== 文件记录相关函数 ==================== + +async def create_file_record( + name: str, + mime_type: str, + size_bytes: int, + api_key: str, + uri: str, + create_time: datetime, + update_time: datetime, + expiration_time: datetime, + state: FileState = FileState.PROCESSING, + display_name: Optional[str] = None, + sha256_hash: Optional[str] = None, + upload_url: Optional[str] = None, + user_token: Optional[str] = None +) -> Dict[str, Any]: + """ + 创建文件记录 + + Args: + name: 文件名称(格式: files/{file_id}) + mime_type: MIME 类型 + size_bytes: 文件大小(字节) + api_key: 上传时使用的 API Key + uri: 文件访问 URI + create_time: 创建时间 + update_time: 更新时间 + expiration_time: 过期时间 + display_name: 显示名称 + sha256_hash: SHA256 哈希值 + upload_url: 临时上传 URL + user_token: 上传用户的 token + + Returns: + Dict[str, Any]: 创建的文件记录 + """ + try: + query = insert(FileRecord).values( + name=name, + display_name=display_name, 
+ mime_type=mime_type, + size_bytes=size_bytes, + sha256_hash=sha256_hash, + state=state, + create_time=create_time, + update_time=update_time, + expiration_time=expiration_time, + uri=uri, + api_key=api_key, + upload_url=upload_url, + user_token=user_token + ) + await database.execute(query) + + # 返回创建的记录 + return await get_file_record_by_name(name) + except Exception as e: + logger.error(f"Failed to create file record: {str(e)}") + raise + + +async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]: + """ + 根据文件名获取文件记录 + + Args: + name: 文件名称(格式: files/{file_id}) + + Returns: + Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None + """ + try: + query = select(FileRecord).where(FileRecord.name == name) + result = await database.fetch_one(query) + return dict(result) if result else None + except Exception as e: + logger.error(f"Failed to get file record by name {name}: {str(e)}") + raise + + + +async def update_file_record_state( + file_name: str, + state: FileState, + update_time: Optional[datetime] = None, + upload_completed: Optional[datetime] = None, + sha256_hash: Optional[str] = None +) -> bool: + """ + 更新文件记录状态 + + Args: + file_name: 文件名 + state: 新状态 + update_time: 更新时间 + upload_completed: 上传完成时间 + sha256_hash: SHA256 哈希值 + + Returns: + bool: 是否更新成功 + """ + try: + values = {"state": state} + if update_time: + values["update_time"] = update_time + if upload_completed: + values["upload_completed"] = upload_completed + if sha256_hash: + values["sha256_hash"] = sha256_hash + + query = update(FileRecord).where(FileRecord.name == file_name).values(**values) + result = await database.execute(query) + + if result: + logger.info(f"Updated file record state for {file_name} to {state}") + return True + + logger.warning(f"File record not found for update: {file_name}") + return False + except Exception as e: + logger.error(f"Failed to update file record state: {str(e)}") + return False + + +async def list_file_records( + user_token: Optional[str] = None, + 
api_key: Optional[str] = None, + page_size: int = 10, + page_token: Optional[str] = None +) -> tuple[List[Dict[str, Any]], Optional[str]]: + """ + 列出文件记录 + + Args: + user_token: 用户 token(如果提供,只返回该用户的文件) + api_key: API Key(如果提供,只返回使用该 key 的文件) + page_size: 每页大小 + page_token: 分页标记(偏移量) + + Returns: + tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记) + """ + try: + logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}") + query = select(FileRecord).where( + FileRecord.expiration_time > datetime.now(timezone.utc) + ) + + if user_token: + query = query.where(FileRecord.user_token == user_token) + if api_key: + query = query.where(FileRecord.api_key == api_key) + + # 使用偏移量进行分页 + offset = 0 + if page_token: + try: + offset = int(page_token) + except ValueError: + logger.warning(f"Invalid page token: {page_token}") + offset = 0 + + # 按ID升序排列,使用 OFFSET 和 LIMIT + query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1) + + results = await database.fetch_all(query) + + logger.debug(f"Query returned {len(results)} records") + if results: + logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}") + + # 处理分页 + has_next = len(results) > page_size + if has_next: + results = results[:page_size] + # 下一页的偏移量是当前偏移量加上本页返回的记录数 + next_offset = offset + page_size + next_page_token = str(next_offset) + logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}") + else: + next_page_token = None + logger.debug(f"No next page, returning {len(results)} results") + + return [dict(row) for row in results], next_page_token + except Exception as e: + logger.error(f"Failed to list file records: {str(e)}") + raise + + +async def delete_file_record(name: str) -> bool: + """ + 删除文件记录 + + Args: + name: 文件名称 + + Returns: + bool: 是否删除成功 + """ + try: + query = delete(FileRecord).where(FileRecord.name == name) + await database.execute(query) + return True + except 
Exception as e: + logger.error(f"Failed to delete file record: {str(e)}") + return False + + +async def delete_expired_file_records() -> List[Dict[str, Any]]: + """ + 删除已过期的文件记录 + + Returns: + List[Dict[str, Any]]: 删除的记录列表 + """ + try: + # 先获取要删除的记录 + query = select(FileRecord).where( + FileRecord.expiration_time <= datetime.now(timezone.utc) + ) + expired_records = await database.fetch_all(query) + + if not expired_records: + return [] + + # 执行删除 + delete_query = delete(FileRecord).where( + FileRecord.expiration_time <= datetime.now(timezone.utc) + ) + await database.execute(delete_query) + + logger.info(f"Deleted {len(expired_records)} expired file records") + return [dict(record) for record in expired_records] + except Exception as e: + logger.error(f"Failed to delete expired file records: {str(e)}") + raise + + +async def get_file_api_key(name: str) -> Optional[str]: + """ + 获取文件对应的 API Key + + Args: + name: 文件名称 + + Returns: + Optional[str]: API Key,如果文件不存在或已过期则返回 None + """ + try: + query = select(FileRecord.api_key).where( + (FileRecord.name == name) & + (FileRecord.expiration_time > datetime.now(timezone.utc)) + ) + result = await database.fetch_one(query) + return result["api_key"] if result else None + except Exception as e: + logger.error(f"Failed to get file API key: {str(e)}") + raise diff --git a/app/domain/file_models.py b/app/domain/file_models.py new file mode 100644 index 0000000..b5cd2e4 --- /dev/null +++ b/app/domain/file_models.py @@ -0,0 +1,69 @@ +""" +Files API 相关的领域模型 +""" +from typing import Optional, Dict, Any, List +from datetime import datetime +from pydantic import BaseModel, Field + + +class FileUploadConfig(BaseModel): + """文件上传配置""" + mime_type: Optional[str] = Field(None, description="MIME 类型") + display_name: Optional[str] = Field(None, description="显示名称,最多40个字符") + + +class CreateFileRequest(BaseModel): + """创建文件请求(用于初始化上传)""" + file: Optional[Dict[str, Any]] = Field(None, description="文件元数据") + + +class 
FileMetadata(BaseModel): + """文件元数据响应""" + name: str = Field(..., description="文件名称,格式: files/{file_id}") + displayName: Optional[str] = Field(None, description="显示名称") + mimeType: str = Field(..., description="MIME 类型") + sizeBytes: str = Field(..., description="文件大小(字节)") + createTime: str = Field(..., description="创建时间 (RFC3339)") + updateTime: str = Field(..., description="更新时间 (RFC3339)") + expirationTime: str = Field(..., description="过期时间 (RFC3339)") + sha256Hash: Optional[str] = Field(None, description="SHA256 哈希值") + uri: str = Field(..., description="文件访问 URI") + state: str = Field(..., description="文件状态") + + class Config: + json_encoders = { + datetime: lambda v: v.isoformat() + "Z" + } + + +class ListFilesRequest(BaseModel): + """列出文件请求参数""" + pageSize: Optional[int] = Field(10, ge=1, le=100, description="每页大小") + pageToken: Optional[str] = Field(None, description="分页标记") + + +class ListFilesResponse(BaseModel): + """列出文件响应""" + files: List[FileMetadata] = Field(default_factory=list, description="文件列表") + nextPageToken: Optional[str] = Field(None, description="下一页标记") + + +class UploadInitResponse(BaseModel): + """上传初始化响应(内部使用)""" + file_metadata: FileMetadata + upload_url: str + + +class FileKeyMapping(BaseModel): + """文件与 API Key 的映射关系(内部使用)""" + file_name: str + api_key: str + user_token: str + created_at: datetime + expires_at: datetime + + +class DeleteFileResponse(BaseModel): + """删除文件响应""" + success: bool = Field(..., description="是否删除成功") + message: Optional[str] = Field(None, description="消息") \ No newline at end of file diff --git a/app/exception/exceptions.py b/app/exception/exceptions.py index 0e9fb30..8cff4ae 100644 --- a/app/exception/exceptions.py +++ b/app/exception/exceptions.py @@ -76,6 +76,8 @@ class ServiceUnavailableError(APIError): ) + + def setup_exception_handlers(app: FastAPI) -> None: """ 设置应用程序的异常处理器 diff --git a/app/log/logger.py b/app/log/logger.py index 1614a46..b6472a5 100644 --- a/app/log/logger.py +++ 
b/app/log/logger.py @@ -228,6 +228,10 @@ def get_request_log_logger(): return Logger.setup_logger("request_log") +def get_files_logger(): + return Logger.setup_logger("files") + + def get_vertex_express_logger(): return Logger.setup_logger("vertex_express") diff --git a/app/middleware/middleware.py b/app/middleware/middleware.py index 85d512f..7cdfcd1 100644 --- a/app/middleware/middleware.py +++ b/app/middleware/middleware.py @@ -34,6 +34,7 @@ class AuthMiddleware(BaseHTTPMiddleware): and not request.url.path.startswith("/openai") and not request.url.path.startswith("/api/version/check") and not request.url.path.startswith("/vertex-express") + and not request.url.path.startswith("/upload") ): auth_token = request.cookies.get("auth_token") diff --git a/app/router/files_routes.py b/app/router/files_routes.py new file mode 100644 index 0000000..234b1f7 --- /dev/null +++ b/app/router/files_routes.py @@ -0,0 +1,296 @@ +""" +Files API 路由 +""" +from typing import Optional +from fastapi import APIRouter, Request, Response, Query, Depends, Body, Header, HTTPException +from fastapi.responses import JSONResponse + +from app.config.config import settings +from app.domain.file_models import ( + FileMetadata, + ListFilesResponse, + CreateFileRequest, + DeleteFileResponse +) +from app.log.logger import get_files_logger +from app.core.security import SecurityService +from app.service.files.files_service import get_files_service +from app.service.files.file_upload_handler import get_upload_handler + +logger = get_files_logger() + +router = APIRouter() +security_service = SecurityService() + + +@router.post("/upload/v1beta/files") +async def upload_file_init( + request: Request, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key), + x_goog_upload_protocol: Optional[str] = Header(None), + x_goog_upload_command: Optional[str] = Header(None), + x_goog_upload_header_content_length: Optional[str] = Header(None), + x_goog_upload_header_content_type: Optional[str] = 
Header(None), +): + """初始化文件上传""" + logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}") + + # 檢查是否是實際的上傳請求(有 upload_id) + if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]: + logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.") + return await handle_upload( + upload_path="v1beta/files", + request=request, + key=request.query_params.get("key"), + auth_token=auth_token + ) + + try: + # 使用认证 token 作为 user_token + user_token = auth_token + # 获取请求体 + body = await request.body() + + # 构建请求主机 URL + request_host = f"{request.url.scheme}://{request.url.netloc}" + logger.info(f"Request host: {request_host}") + + # 准备请求头 + headers = { + "x-goog-upload-protocol": x_goog_upload_protocol or "resumable", + "x-goog-upload-command": x_goog_upload_command or "start", + } + + if x_goog_upload_header_content_length: + headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length + if x_goog_upload_header_content_type: + headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type + + # 调用服务 + files_service = await get_files_service() + response_data, response_headers = await files_service.initialize_upload( + headers=headers, + body=body, + user_token=user_token, + request_host=request_host # 傳遞請求主機 + ) + + logger.info(f"Upload initialization response: {response_data}") + logger.info(f"Upload initialization response headers: {response_headers}") + + # 返回响应 + return JSONResponse( + content=response_data, + headers=response_headers + ) + + except HTTPException as e: + logger.error(f"Upload initialization failed: {e.detail}") + return JSONResponse( + content={"error": {"message": e.detail}}, + status_code=e.status_code + 
) + except Exception as e: + logger.error(f"Unexpected error in upload initialization: {str(e)}") + return JSONResponse( + content={"error": {"message": "Internal server error"}}, + status_code=500 + ) + + +@router.get("/v1beta/files") +async def list_files( + page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"), + page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"), + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> ListFilesResponse: + """列出文件""" + logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}") + try: + # 使用认证 token 作为 user_token(如果启用用户隔离) + user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None + # 调用服务 + files_service = await get_files_service() + return await files_service.list_files( + page_size=page_size, + page_token=page_token, + user_token=user_token + ) + + except HTTPException as e: + logger.error(f"List files failed: {e.detail}") + return JSONResponse( + content={"error": {"message": e.detail}}, + status_code=e.status_code + ) + except Exception as e: + logger.error(f"Unexpected error in list files: {str(e)}") + return JSONResponse( + content={"error": {"message": "Internal server error"}}, + status_code=500 + ) + + +@router.get("/v1beta/files/{file_id:path}") +async def get_file( + file_id: str, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> FileMetadata: + """获取文件信息""" + logger.debug(f"Get file request: {file_id=}, {auth_token=}") + try: + # 使用认证 token 作为 user_token + user_token = auth_token + # 调用服务 + files_service = await get_files_service() + return await files_service.get_file(f"files/{file_id}", user_token) + + except HTTPException as e: + logger.error(f"Get file failed: {e.detail}") + return JSONResponse( + content={"error": {"message": e.detail}}, + status_code=e.status_code + ) + except Exception as e: + logger.error(f"Unexpected error in get file: {str(e)}") + return JSONResponse( + 
content={"error": {"message": "Internal server error"}}, + status_code=500 + ) + + +@router.delete("/v1beta/files/{file_id:path}") +async def delete_file( + file_id: str, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> DeleteFileResponse: + """删除文件""" + logger.info(f"Delete file: {file_id=}, {auth_token=}") + try: + # 使用认证 token 作为 user_token + user_token = auth_token + # 调用服务 + files_service = await get_files_service() + success = await files_service.delete_file(f"files/{file_id}", user_token) + + return DeleteFileResponse( + success=success, + message="File deleted successfully" if success else "Failed to delete file" + ) + + except HTTPException as e: + logger.error(f"Delete file failed: {e.detail}") + return JSONResponse( + content={"error": {"message": e.detail}}, + status_code=e.status_code + ) + except Exception as e: + logger.error(f"Unexpected error in delete file: {str(e)}") + return JSONResponse( + content={"error": {"message": "Internal server error"}}, + status_code=500 + ) + + +# 处理上传请求的通配符路由 +@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"]) +async def handle_upload( + upload_path: str, + request: Request, + key: Optional[str] = Query(None), # 從查詢參數獲取 key + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +): + """处理文件上传请求""" + try: + logger.info(f"Handling upload request: {request.method} {upload_path}, key={key}") + + # 從查詢參數獲取 upload_id + upload_id = request.query_params.get("upload_id") + if not upload_id: + raise HTTPException(status_code=400, detail="Missing upload_id") + + # 從 session 獲取真實的 API key + files_service = await get_files_service() + session_info = await files_service.get_upload_session(upload_id) + if not session_info: + logger.error(f"No session found for upload_id: {upload_id}") + raise HTTPException(status_code=404, detail="Upload session not found") + + real_api_key = session_info["api_key"] + original_upload_url = session_info["upload_url"] + + # 
使用真實的 API key 構建完整的 Google 上傳 URL + # 保留原始 URL 的所有參數,但使用真實的 API key + upload_url = original_upload_url + logger.info(f"Using real API key for upload: {real_api_key[:8]}...{real_api_key[-4:]}") + + # 代理上传请求 + upload_handler = get_upload_handler() + return await upload_handler.proxy_upload_request( + request=request, + upload_url=upload_url, + files_service=files_service + ) + + except HTTPException as e: + logger.error(f"Upload handling failed: {e.detail}") + return JSONResponse( + content={"error": {"message": e.detail}}, + status_code=e.status_code + ) + except Exception as e: + logger.error(f"Unexpected error in upload handling: {str(e)}") + return JSONResponse( + content={"error": {"message": "Internal server error"}}, + status_code=500 + ) + + +# 为兼容性添加 /gemini 前缀的路由 +@router.post("/gemini/upload/v1beta/files") +async def gemini_upload_file_init( + request: Request, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key), + x_goog_upload_protocol: Optional[str] = Header(None), + x_goog_upload_command: Optional[str] = Header(None), + x_goog_upload_header_content_length: Optional[str] = Header(None), + x_goog_upload_header_content_type: Optional[str] = Header(None), +): + """初始化文件上传(Gemini 前缀)""" + return await upload_file_init( + request, + auth_token, + x_goog_upload_protocol, + x_goog_upload_command, + x_goog_upload_header_content_length, + x_goog_upload_header_content_type + ) + + +@router.get("/gemini/v1beta/files") +async def gemini_list_files( + page_size: int = Query(10, ge=1, le=100, alias="pageSize"), + page_token: Optional[str] = Query(None, alias="pageToken"), + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> ListFilesResponse: + """列出文件(Gemini 前缀)""" + return await list_files(page_size, page_token, auth_token) + + +@router.get("/gemini/v1beta/files/{file_id:path}") +async def gemini_get_file( + file_id: str, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> FileMetadata: + 
"""获取文件信息(Gemini 前缀)""" + return await get_file(file_id, auth_token) + + +@router.delete("/gemini/v1beta/files/{file_id:path}") +async def gemini_delete_file( + file_id: str, + auth_token: str = Depends(security_service.verify_key_or_goog_api_key) +) -> DeleteFileResponse: + """删除文件(Gemini 前缀)""" + return await delete_file(file_id, auth_token) \ No newline at end of file diff --git a/app/router/routes.py b/app/router/routes.py index 46004d7..bc916f6 100644 --- a/app/router/routes.py +++ b/app/router/routes.py @@ -8,7 +8,7 @@ from fastapi.templating import Jinja2Templates from app.core.security import verify_auth_token from app.log.logger import get_routes_logger -from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes +from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes, files_routes from app.service.key.key_manager import get_key_manager_instance from app.service.stats.stats_service import StatsService @@ -34,6 +34,7 @@ def setup_routers(app: FastAPI) -> None: app.include_router(version_routes.router) app.include_router(openai_compatiable_routes.router) app.include_router(vertex_express_routes.router) + app.include_router(files_routes.router) setup_page_routes(app) diff --git a/app/scheduler/scheduled_tasks.py b/app/scheduler/scheduled_tasks.py index be57a23..b185f1e 100644 --- a/app/scheduler/scheduled_tasks.py +++ b/app/scheduler/scheduled_tasks.py @@ -8,6 +8,7 @@ from app.service.chat.gemini_chat_service import GeminiChatService from app.service.error_log.error_log_service import delete_old_error_logs from app.service.key.key_manager import get_key_manager_instance from app.service.request_log.request_log_service import delete_old_request_logs_task +from app.service.files.files_service import 
get_files_service logger = Logger.setup_logger("scheduler") @@ -96,6 +97,26 @@ async def check_failed_keys(): ) +async def cleanup_expired_files(): + """ + 定时清理过期的文件记录 + """ + logger.info("Starting scheduled cleanup for expired files...") + try: + files_service = await get_files_service() + deleted_count = await files_service.cleanup_expired_files() + + if deleted_count > 0: + logger.info(f"Successfully cleaned up {deleted_count} expired files.") + else: + logger.info("No expired files to clean up.") + + except Exception as e: + logger.error( + f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True + ) + + def setup_scheduler(): """设置并启动 APScheduler""" scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区 @@ -134,6 +155,20 @@ def setup_scheduler(): logger.info( f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days." ) + + # 新增:添加文件过期清理的定时任务,每小时执行一次 + if getattr(settings, 'FILES_CLEANUP_ENABLED', True): + cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1) + scheduler.add_job( + cleanup_expired_files, + "interval", + hours=cleanup_interval, + id="cleanup_expired_files_job", + name="Cleanup Expired Files", + ) + logger.info( + f"File cleanup job scheduled to run every {cleanup_interval} hour(s)." 
+ ) scheduler.start() logger.info("Scheduler started with all jobs.") diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py index a5a2746..ca49ca8 100644 --- a/app/service/chat/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -13,7 +13,7 @@ from app.handler.stream_optimizer import gemini_optimizer from app.log.logger import get_gemini_logger from app.service.client.api_client import GeminiApiClient from app.service.key.key_manager import KeyManager -from app.database.services import add_error_log, add_request_log +from app.database.services import add_error_log, add_request_log, get_file_api_key logger = get_gemini_logger() @@ -27,6 +27,28 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: return True return False +def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]: + """從內容中提取文件引用""" + file_names = [] + for content in contents: + if "parts" in content: + for part in content["parts"]: + if not isinstance(part, dict) or "fileData" not in part: + continue + file_data = part["fileData"] + if "fileUri" not in file_data: + continue + file_uri = file_data["fileUri"] + # 從 URI 中提取文件名 + # 1. 
https://generativelanguage.googleapis.com/v1beta/files/{file_id} + match = re.match(r"https://generativelanguage.googleapis.com/v1beta/(files/.*)", file_uri) + if not match: + logger.warning(f"Invalid file URI: {file_uri}") + continue + file_id = match.group(1) + file_names.append(file_id) + logger.info(f"Found file reference: {file_id}") + return file_names def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: """构建工具""" @@ -172,6 +194,17 @@ class GeminiChatService: self, model: str, request: GeminiRequest, api_key: str ) -> Dict[str, Any]: """生成内容""" + # 檢查並獲取文件專用的 API key(如果有文件) + file_names = _extract_file_references(request.model_dump().get("contents", [])) + if file_names: + logger.info(f"Request contains file references: {file_names}") + file_api_key = await get_file_api_key(file_names[0]) + if file_api_key: + logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}") + api_key = file_api_key # 使用文件的 API key + else: + logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}") + payload = _build_payload(model, request) start_time = time.perf_counter() request_datetime = datetime.datetime.now() @@ -267,6 +300,17 @@ class GeminiChatService: self, model: str, request: GeminiRequest, api_key: str ) -> AsyncGenerator[str, None]: """流式生成内容""" + # 檢查並獲取文件專用的 API key(如果有文件) + file_names = _extract_file_references(request.model_dump().get("contents", [])) + if file_names: + logger.info(f"Request contains file references: {file_names}") + file_api_key = await get_file_api_key(file_names[0]) + if file_api_key: + logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}") + api_key = file_api_key # 使用文件的 API key + else: + logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}") + retries = 0 max_retries = settings.MAX_RETRIES payload = _build_payload(model, request) 
"""
File upload handler.

Implements the proxy side of Google's resumable-upload protocol for the
Gemini Files API: forwards upload chunks to the real Google upload URL and,
when the final ("finalize") chunk succeeds, records the uploaded file in the
local database via the files service's upload-session bookkeeping.
"""
from datetime import datetime, timezone, timedelta
from typing import Optional

from httpx import AsyncClient
from fastapi import Request, Response, HTTPException

from app.database import services as db_services
from app.database.models import FileState
from app.log.logger import get_files_logger

logger = get_files_logger()

# Request headers that must be forwarded verbatim to Google's upload endpoint.
_FORWARDED_UPLOAD_HEADERS = (
    "x-goog-upload-command",
    "x-goog-upload-offset",
    "content-type",
    "content-length",
)


def _canonical_header_name(header: str) -> str:
    """Convert a lower-case header name to canonical form (``x-goog-upload-command`` -> ``X-Goog-Upload-Command``)."""
    return "-".join(word.capitalize() for word in header.split("-"))


def _parse_google_timestamp(raw: Optional[str], default: datetime) -> datetime:
    """
    Parse an RFC 3339 timestamp as returned by Google.

    Google may return nanosecond precision (e.g. ``2025-07-11T02:02:52.531916141Z``)
    while ``datetime.fromisoformat`` only accepts up to microseconds, so the
    fractional part is truncated to 6 digits and the trailing ``Z`` is mapped
    to ``+00:00``. Returns ``default`` when ``raw`` is missing/empty.
    """
    if not raw:
        return default
    if raw.endswith("Z"):
        raw = raw[:-1]
        if "." in raw:
            date_part, frac_part = raw.rsplit(".", 1)
            if len(frac_part) > 6:
                frac_part = frac_part[:6]
            raw = f"{date_part}.{frac_part}"
        raw += "+00:00"
    return datetime.fromisoformat(raw)


def _state_from_string(value: str) -> FileState:
    """Map Google's state string onto the FileState enum, defaulting to PROCESSING for unknown values."""
    try:
        return FileState(value)
    except ValueError:
        logger.warning(f"Unknown file state: {value}, defaulting to PROCESSING")
        return FileState.PROCESSING


class FileUploadHandler:
    """Proxies resumable-upload chunks to Google and tracks upload completion."""

    def __init__(self):
        # Informational only: chunking is driven by the client, not by this proxy.
        self.chunk_size = 8 * 1024 * 1024  # 8MB

    async def handle_upload_chunk(
        self,
        upload_url: str,
        request: Request,
        files_service=None,
    ) -> Response:
        """
        Forward one upload chunk to Google's upload URL.

        Args:
            upload_url: The real Google upload URL for this session.
            request: Incoming FastAPI request carrying the chunk.
            files_service: FilesService instance used to resolve the upload
                session when the upload finalizes (may be None).

        Returns:
            Response: Google's response, passed through.

        Raises:
            HTTPException: on upstream failure or internal error.
        """
        try:
            # Copy only the protocol headers Google needs, in canonical form.
            headers = {}
            for header in _FORWARDED_UPLOAD_HEADERS:
                if header in request.headers:
                    headers[_canonical_header_name(header)] = request.headers[header]

            body = await request.body()

            # The final chunk carries "finalize" in the upload command.
            is_final = "finalize" in headers.get("X-Goog-Upload-Command", "")
            logger.debug(f"Upload command: {headers.get('X-Goog-Upload-Command', '')}, is_final: {is_final}")

            async with AsyncClient() as client:
                # FIX: forward the caller's own HTTP method instead of always
                # POSTing — resumable-protocol chunks may arrive via PUT.
                method = request.method if request.method in ("POST", "PUT") else "POST"
                response = await client.request(
                    method,
                    upload_url,
                    headers=headers,
                    content=body,
                    timeout=300.0,  # 5-minute timeout for large chunks
                )

            if response.status_code not in (200, 201, 308):
                logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
                raise HTTPException(status_code=response.status_code, detail="Upload failed")

            if is_final and response.status_code in (200, 201):
                logger.debug(f"Upload finalized with status {response.status_code}")
                try:
                    await self._record_completed_upload(response, upload_url, files_service)
                    # Pass Google's file metadata straight back to the client.
                    return Response(
                        content=response.content,
                        status_code=response.status_code,
                        headers=dict(response.headers),
                    )
                except Exception as e:
                    # Best effort: a bookkeeping failure must not fail the upload.
                    logger.error(f"Failed to create file record: {str(e)}", exc_info=True)
            else:
                logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")

            response_headers = dict(response.headers)
            # 308 Resume Incomplete must advertise an active upload status.
            if response.status_code == 308:
                response_headers.setdefault("x-goog-upload-status", "active")

            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to handle upload chunk: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def _record_completed_upload(self, response, upload_url: str, files_service) -> None:
        """Parse the finalize response and persist a file record for the completed upload."""
        response_data = response.json()
        logger.debug(f"Upload complete response data: {response_data}")
        file_data = response_data.get("file", {})

        real_file_name = file_data.get("name")
        if not real_file_name or files_service is None:
            logger.warning(
                f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}"
            )
            return
        logger.info(f"Upload completed, file name: {real_file_name}")

        session_info = await files_service.get_upload_session(upload_url)
        logger.debug(f"Retrieved session info for {upload_url}: {session_info}")
        if not session_info:
            logger.warning(f"No upload session found for URL: {upload_url}")
            return

        now = datetime.now(timezone.utc)
        # Fall back to a 48-hour lifetime when Google omits expirationTime.
        expiration = _parse_google_timestamp(
            file_data.get("expirationTime"), default=now + timedelta(hours=48)
        )
        state_enum = _state_from_string(file_data.get("state", "PROCESSING"))

        await db_services.create_file_record(
            name=real_file_name,
            mime_type=file_data.get("mimeType", session_info["mime_type"]),
            size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
            api_key=session_info["api_key"],
            uri=file_data.get("uri", f"https://generativelanguage.googleapis.com/v1beta/{real_file_name}"),
            create_time=now,
            update_time=now,
            expiration_time=expiration,
            state=state_enum,
            display_name=file_data.get("displayName", session_info.get("display_name", "")),
            sha256_hash=file_data.get("sha256Hash"),
            user_token=session_info["user_token"],
        )
        logger.info(
            f"Created file record: name={real_file_name}, api_key={session_info['api_key'][:8]}...{session_info['api_key'][-4:]}"
        )

    async def proxy_upload_request(
        self,
        request: Request,
        upload_url: str,
        files_service=None,
    ) -> Response:
        """
        Route an upload request: GET queries status, POST/PUT forward a chunk.

        Args:
            request: Incoming FastAPI request.
            upload_url: Target Google upload URL.
            files_service: FilesService instance (passed through).

        Returns:
            Response: the proxied response.
        """
        logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
        try:
            if request.method == "GET":
                return await self._get_upload_status(upload_url)
            return await self.handle_upload_chunk(upload_url, request, files_service)
        except Exception as e:
            logger.error(f"Failed to proxy upload request: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def _get_upload_status(self, upload_url: str) -> Response:
        """Query the upload URL for the current upload status and relay the response."""
        try:
            async with AsyncClient() as client:
                # FIX: bound the status probe too (the original had no timeout here).
                response = await client.get(upload_url, timeout=30.0)
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=dict(response.headers),
            )
        except Exception as e:
            logger.error(f"Failed to get upload status: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


# Module-level singleton.
_upload_handler_instance: Optional[FileUploadHandler] = None


def get_upload_handler() -> FileUploadHandler:
    """Return the process-wide FileUploadHandler singleton, creating it on first use."""
    global _upload_handler_instance
    if _upload_handler_instance is None:
        _upload_handler_instance = FileUploadHandler()
    return _upload_handler_instance
"""
File management service for the Gemini Files API proxy.
"""
import asyncio
import json
import re
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple

from fastapi import HTTPException
from httpx import AsyncClient

from app.config.config import settings
from app.database import services as db_services
from app.database.models import FileState
from app.domain.file_models import FileMetadata, ListFilesResponse
from app.log.logger import get_files_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import get_key_manager_instance

logger = get_files_logger()

# Global upload-session store, keyed by Google's upload_id.
_upload_sessions: Dict[str, Dict[str, Any]] = {}
_upload_sessions_lock = asyncio.Lock()
# FIX: strong references to fire-and-forget cleanup tasks — without these the
# event loop may garbage-collect a task before it ever runs.
_background_tasks: set = set()


class FilesService:
    """Service that proxies Gemini Files API calls and tracks uploads locally."""

    def __init__(self):
        self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
        self.key_manager = None

    async def _get_key_manager(self):
        """Lazily resolve and cache the shared KeyManager instance."""
        if not self.key_manager:
            self.key_manager = await get_key_manager_instance(
                settings.API_KEYS,
                settings.VERTEX_API_KEYS,
            )
        return self.key_manager

    async def initialize_upload(
        self,
        headers: Dict[str, str],
        body: Optional[bytes],
        user_token: str,
        request_host: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """
        Initialize a resumable upload against the real Gemini API.

        Picks an API key, forwards the "start" request to Google, remembers the
        upload session (keyed by Google's upload_id), and rewrites the returned
        upload URL so subsequent chunks flow through this proxy.

        Args:
            headers: Incoming request headers (lower-cased keys).
            body: Raw request body (may contain a JSON displayName).
            user_token: The caller's proxy token (substituted for the real key
                in the proxied upload URL).
            request_host: Scheme+host of this proxy, used to rewrite the URL.

        Returns:
            Tuple of (response body dict, response headers dict).

        Raises:
            HTTPException: 503 when no key is available; upstream status on
                Google failure; 500 on unexpected errors.
        """
        try:
            key_manager = await self._get_key_manager()
            api_key = await key_manager.get_next_key()
            if not api_key:
                raise HTTPException(status_code=503, detail="No available API keys")

            async with AsyncClient() as client:
                # Forward only the resumable-protocol headers Google expects.
                forward_headers = {
                    "X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
                    "X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
                    "Content-Type": headers.get("content-type", "application/json"),
                }
                if "x-goog-upload-header-content-length" in headers:
                    forward_headers["X-Goog-Upload-Header-Content-Length"] = headers["x-goog-upload-header-content-length"]
                if "x-goog-upload-header-content-type" in headers:
                    forward_headers["X-Goog-Upload-Header-Content-Type"] = headers["x-goog-upload-header-content-type"]

                response = await client.post(
                    "https://generativelanguage.googleapis.com/upload/v1beta/files",
                    headers=forward_headers,
                    content=body,
                    params={"key": api_key},
                )

            if response.status_code != 200:
                logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
                raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")

            upload_url = response.headers.get("x-goog-upload-url")
            if not upload_url:
                raise HTTPException(status_code=500, detail="No upload URL in response")
            logger.info(f"Original upload URL from Google: {upload_url}")

            # No DB record yet — it is created once the upload finalizes.
            logger.info(f"Upload initialized with API key: {api_key[:8]}...{api_key[-4:]}")

            # The init response body is empty; metadata comes from the request.
            response_data: Dict[str, Any] = {}

            display_name = ""
            if body:
                try:
                    display_name = json.loads(body).get("displayName", "")
                except (ValueError, UnicodeDecodeError):
                    # FIX: was a bare `except: pass`. Only malformed JSON /
                    # encoding problems are expected here; anything else
                    # should propagate instead of being silently swallowed.
                    logger.debug("Upload init body is not JSON; no displayName extracted")

            parsed_url = urllib.parse.urlparse(upload_url)
            upload_id = urllib.parse.parse_qs(parsed_url.query).get("upload_id", [None])[0]

            if upload_id:
                async with _upload_sessions_lock:
                    _upload_sessions[upload_id] = {
                        "api_key": api_key,
                        "user_token": user_token,
                        "display_name": display_name,
                        "mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
                        "size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
                        "created_at": datetime.now(timezone.utc),
                        "upload_url": upload_url,
                    }
                logger.info(f"Stored upload session for upload_id={upload_id}: api_key={api_key[:8]}...{api_key[-4:]}")
                logger.debug(f"Total active sessions: {len(_upload_sessions)}")
            else:
                logger.warning(f"No upload_id found in upload URL: {upload_url}")

            # FIX: keep a strong reference to the cleanup task so it cannot be
            # garbage-collected before running.
            task = asyncio.create_task(self._cleanup_expired_sessions())
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)

            # Rewrite Google's URL so chunks go through this proxy, and swap
            # the real API key for the caller's token.
            proxy_upload_url = upload_url
            if request_host:
                proxy_upload_url = upload_url.replace(
                    "https://generativelanguage.googleapis.com",
                    request_host.rstrip("/"),
                )
                match = re.search(r"(\?|&)key=([^&]+)", proxy_upload_url)
                if match:
                    proxy_upload_url = proxy_upload_url.replace(
                        f"{match.group(1)}key={match.group(2)}",
                        f"{match.group(1)}key={user_token}",
                    )
                logger.info(f"Replaced upload URL: {upload_url} -> {proxy_upload_url}")

            return response_data, {
                "X-Goog-Upload-URL": proxy_upload_url,
                "X-Goog-Upload-Status": "active",
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to initialize upload: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def _cleanup_expired_sessions(self):
        """Drop upload sessions older than one hour from the in-memory store."""
        try:
            async with _upload_sessions_lock:
                now = datetime.now(timezone.utc)
                expired_keys = [
                    key
                    for key, session in _upload_sessions.items()
                    if now - session["created_at"] > timedelta(hours=1)
                ]
                for key in expired_keys:
                    del _upload_sessions[key]
                if expired_keys:
                    logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
        except Exception as e:
            logger.error(f"Error cleaning up upload sessions: {str(e)}")
key, session in _upload_sessions.items(): + if now - session["created_at"] > timedelta(hours=1): + expired_keys.append(key) + + for key in expired_keys: + del _upload_sessions[key] + + if expired_keys: + logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions") + except Exception as e: + logger.error(f"Error cleaning up upload sessions: {str(e)}") + + async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]: + """獲取上傳會話信息(支持 upload_id 或完整 URL)""" + async with _upload_sessions_lock: + # 先嘗試直接查找 + session = _upload_sessions.get(key) + if session: + logger.debug(f"Found session by direct key {key}") + return session + + # 如果是 URL,嘗試提取 upload_id + if key.startswith("http"): + import urllib.parse + parsed_url = urllib.parse.urlparse(key) + query_params = urllib.parse.parse_qs(parsed_url.query) + upload_id = query_params.get('upload_id', [None])[0] + if upload_id: + session = _upload_sessions.get(upload_id) + if session: + logger.debug(f"Found session by upload_id {upload_id} from URL") + return session + + logger.debug(f"No session found for key: {key}") + return None + + async def get_file(self, file_name: str, user_token: str) -> FileMetadata: + """ + 获取文件信息 + + Args: + file_name: 文件名称 (格式: files/{file_id}) + user_token: 用户令牌 + + Returns: + FileMetadata: 文件元数据 + """ + try: + # 查询文件记录 + file_record = await db_services.get_file_record_by_name(file_name) + + if not file_record: + raise HTTPException(status_code=404, detail="File not found") + + # 检查是否过期 + expiration_time = datetime.fromisoformat(str(file_record["expiration_time"])) + # 如果是 naive datetime,假设为 UTC + if expiration_time.tzinfo is None: + expiration_time = expiration_time.replace(tzinfo=timezone.utc) + if expiration_time <= datetime.now(timezone.utc): + raise HTTPException(status_code=404, detail="File has expired") + + # 使用原始 API key 获取文件信息 + api_key = file_record["api_key"] + + async with AsyncClient() as client: + response = await client.get( + 
f"https://generativelanguage.googleapis.com/v1beta/{file_name}", + params={"key": api_key} + ) + + if response.status_code != 200: + logger.error(f"Failed to get file: {response.status_code} - {response.text}") + raise HTTPException(status_code=response.status_code, detail="Failed to get file") + + file_data = response.json() + + # 檢查並更新文件狀態 + google_state = file_data.get("state", "PROCESSING") + if google_state != file_record.get("state", "").value if file_record.get("state") else None: + logger.info(f"File state changed from {file_record.get('state')} to {google_state}") + # 更新數據庫中的狀態 + if google_state == "ACTIVE": + await db_services.update_file_record_state( + file_name=file_name, + state=FileState.ACTIVE, + update_time=datetime.now(timezone.utc) + ) + elif google_state == "FAILED": + await db_services.update_file_record_state( + file_name=file_name, + state=FileState.FAILED, + update_time=datetime.now(timezone.utc) + ) + + # 构建响应 + return FileMetadata( + name=file_data["name"], + displayName=file_data.get("displayName"), + mimeType=file_data["mimeType"], + sizeBytes=str(file_data["sizeBytes"]), + createTime=file_data["createTime"], + updateTime=file_data["updateTime"], + expirationTime=file_data["expirationTime"], + sha256Hash=file_data.get("sha256Hash"), + uri=file_data["uri"], + state=google_state + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get file {file_name}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") + + async def list_files( + self, + page_size: int = 10, + page_token: Optional[str] = None, + user_token: Optional[str] = None + ) -> ListFilesResponse: + """ + 列出文件 + + Args: + page_size: 每页大小 + page_token: 分页标记 + user_token: 用户令牌(可选,如果提供则只返回该用户的文件) + + Returns: + ListFilesResponse: 文件列表响应 + """ + try: + logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}") + + # 从数据库获取文件列表 + files, next_page_token = await db_services.list_file_records( 
+ user_token=user_token, + page_size=page_size, + page_token=page_token + ) + + logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}") + + # 转换为响应格式 + file_list = [] + for file_record in files: + file_list.append(FileMetadata( + name=file_record["name"], + displayName=file_record.get("display_name"), + mimeType=file_record["mime_type"], + sizeBytes=str(file_record["size_bytes"]), + createTime=file_record["create_time"].isoformat() + "Z", + updateTime=file_record["update_time"].isoformat() + "Z", + expirationTime=file_record["expiration_time"].isoformat() + "Z", + sha256Hash=file_record.get("sha256_hash"), + uri=file_record["uri"], + state=file_record["state"].value if file_record.get("state") else "ACTIVE" + )) + + response = ListFilesResponse( + files=file_list, + nextPageToken=next_page_token + ) + + logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}") + + return response + + except Exception as e: + logger.error(f"Failed to list files: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") + + async def delete_file(self, file_name: str, user_token: str) -> bool: + """ + 删除文件 + + Args: + file_name: 文件名称 + user_token: 用户令牌 + + Returns: + bool: 是否删除成功 + """ + try: + # 查询文件记录 + file_record = await db_services.get_file_record_by_name(file_name) + + if not file_record: + raise HTTPException(status_code=404, detail="File not found") + + # 使用原始 API key 删除文件 + api_key = file_record["api_key"] + + async with AsyncClient() as client: + response = await client.delete( + f"https://generativelanguage.googleapis.com/v1beta/{file_name}", + params={"key": api_key} + ) + + if response.status_code not in [200, 204]: + logger.error(f"Failed to delete file: {response.status_code} - {response.text}") + # 如果 API 删除失败,但文件已过期,仍然删除数据库记录 + expiration_time = datetime.fromisoformat(str(file_record["expiration_time"])) + if expiration_time.tzinfo is None: + expiration_time 
= expiration_time.replace(tzinfo=timezone.utc) + if expiration_time <= datetime.now(timezone.utc): + await db_services.delete_file_record(file_name) + return True + raise HTTPException(status_code=response.status_code, detail="Failed to delete file") + + # 删除数据库记录 + await db_services.delete_file_record(file_name) + return True + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete file {file_name}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") + + async def check_file_state(self, file_name: str, api_key: str) -> str: + """ + 檢查並更新文件狀態 + + Args: + file_name: 文件名稱 + api_key: API密鑰 + + Returns: + str: 當前狀態 + """ + try: + async with AsyncClient() as client: + response = await client.get( + f"https://generativelanguage.googleapis.com/v1beta/{file_name}", + params={"key": api_key} + ) + + if response.status_code != 200: + logger.error(f"Failed to check file state: {response.status_code}") + return "UNKNOWN" + + file_data = response.json() + google_state = file_data.get("state", "PROCESSING") + + # 更新數據庫狀態 + if google_state == "ACTIVE": + await db_services.update_file_record_state( + file_name=file_name, + state=FileState.ACTIVE, + update_time=datetime.now(timezone.utc) + ) + elif google_state == "FAILED": + await db_services.update_file_record_state( + file_name=file_name, + state=FileState.FAILED, + update_time=datetime.now(timezone.utc) + ) + + return google_state + + except Exception as e: + logger.error(f"Failed to check file state: {str(e)}") + return "UNKNOWN" + + async def cleanup_expired_files(self) -> int: + """ + 清理过期文件 + + Returns: + int: 清理的文件数量 + """ + try: + # 获取过期文件 + expired_files = await db_services.delete_expired_file_records() + + if not expired_files: + return 0 + + # 尝试从 Gemini API 删除文件 + for file_record in expired_files: + try: + api_key = file_record["api_key"] + file_name = file_record["name"] + + async with AsyncClient() as client: + await client.delete( + 
f"https://generativelanguage.googleapis.com/v1beta/{file_name}", + params={"key": api_key} + ) + except Exception as e: + # 记录错误但继续处理其他文件 + logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}") + + return len(expired_files) + + except Exception as e: + logger.error(f"Failed to cleanup expired files: {str(e)}") + return 0 + + +# 单例实例 +_files_service_instance: Optional[FilesService] = None + + +async def get_files_service() -> FilesService: + """获取文件服务单例实例""" + global _files_service_instance + if _files_service_instance is None: + _files_service_instance = FilesService() + return _files_service_instance \ No newline at end of file