mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-06-02 22:30:21 +08:00
feat: Add Files API support with upload, list, get and delete operations
- Implement complete Files API compatible with Gemini API format - Support resumable file uploads with chunked transfer (tested with 15MB video) - Create file management service with database tracking - Add file domain models and API request/response objects - Implement file routes with proper authentication - Use fixed API key for Files API requests (due to Google API restrictions) - Support file state management (PROCESSING, ACTIVE, FAILED) - Add scheduled task for automatic expired file cleanup - Integrate seamlessly with existing key management and load balancing
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.connection import engine, Base
|
||||
from app.database.models import Settings
|
||||
from app.database.models import Settings, FileRecord
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
||||
import enum
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
@@ -60,3 +61,69 @@ class RequestLog(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
|
||||
|
||||
class FileState(enum.Enum):
|
||||
"""文件状态枚举"""
|
||||
PROCESSING = "PROCESSING"
|
||||
ACTIVE = "ACTIVE"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
"""
|
||||
文件记录表,用于存储上传到 Gemini 的文件信息
|
||||
"""
|
||||
__tablename__ = "t_file_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# 文件基本信息
|
||||
name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
|
||||
display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
|
||||
mime_type = Column(String(100), nullable=False, comment="MIME 类型")
|
||||
size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
|
||||
sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")
|
||||
|
||||
# 状态信息
|
||||
state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")
|
||||
|
||||
# 时间戳
|
||||
create_time = Column(DateTime, nullable=False, comment="创建时间")
|
||||
update_time = Column(DateTime, nullable=False, comment="更新时间")
|
||||
expiration_time = Column(DateTime, nullable=False, comment="过期时间")
|
||||
|
||||
# API 相关
|
||||
uri = Column(String(500), nullable=False, comment="文件访问 URI")
|
||||
api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
|
||||
upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")
|
||||
|
||||
# 额外信息
|
||||
user_token = Column(String(100), nullable=True, comment="上传用户的 token")
|
||||
upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式,用于 API 响应"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"displayName": self.display_name,
|
||||
"mimeType": self.mime_type,
|
||||
"sizeBytes": str(self.size_bytes),
|
||||
"createTime": self.create_time.isoformat() + "Z",
|
||||
"updateTime": self.update_time.isoformat() + "Z",
|
||||
"expirationTime": self.expiration_time.isoformat() + "Z",
|
||||
"sha256Hash": self.sha256_hash,
|
||||
"uri": self.uri,
|
||||
"state": self.state.value if self.state else "PROCESSING"
|
||||
}
|
||||
|
||||
def is_expired(self):
|
||||
"""检查文件是否已过期"""
|
||||
# 确保比较时都是 timezone-aware
|
||||
expiration_time = self.expiration_time
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
|
||||
return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
import json
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog
|
||||
from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
@@ -427,3 +427,264 @@ async def add_request_log(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 文件记录相关函数 ====================
|
||||
|
||||
async def create_file_record(
|
||||
name: str,
|
||||
mime_type: str,
|
||||
size_bytes: int,
|
||||
api_key: str,
|
||||
uri: str,
|
||||
create_time: datetime,
|
||||
update_time: datetime,
|
||||
expiration_time: datetime,
|
||||
state: FileState = FileState.PROCESSING,
|
||||
display_name: Optional[str] = None,
|
||||
sha256_hash: Optional[str] = None,
|
||||
upload_url: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
mime_type: MIME 类型
|
||||
size_bytes: 文件大小(字节)
|
||||
api_key: 上传时使用的 API Key
|
||||
uri: 文件访问 URI
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
expiration_time: 过期时间
|
||||
display_name: 显示名称
|
||||
sha256_hash: SHA256 哈希值
|
||||
upload_url: 临时上传 URL
|
||||
user_token: 上传用户的 token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 创建的文件记录
|
||||
"""
|
||||
try:
|
||||
query = insert(FileRecord).values(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
mime_type=mime_type,
|
||||
size_bytes=size_bytes,
|
||||
sha256_hash=sha256_hash,
|
||||
state=state,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
expiration_time=expiration_time,
|
||||
uri=uri,
|
||||
api_key=api_key,
|
||||
upload_url=upload_url,
|
||||
user_token=user_token
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
# 返回创建的记录
|
||||
return await get_file_record_by_name(name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据文件名获取文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord).where(FileRecord.name == name)
|
||||
result = await database.fetch_one(query)
|
||||
return dict(result) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file record by name {name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def update_file_record_state(
|
||||
file_name: str,
|
||||
state: FileState,
|
||||
update_time: Optional[datetime] = None,
|
||||
upload_completed: Optional[datetime] = None,
|
||||
sha256_hash: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新文件记录状态
|
||||
|
||||
Args:
|
||||
file_name: 文件名
|
||||
state: 新状态
|
||||
update_time: 更新时间
|
||||
upload_completed: 上传完成时间
|
||||
sha256_hash: SHA256 哈希值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
values = {"state": state}
|
||||
if update_time:
|
||||
values["update_time"] = update_time
|
||||
if upload_completed:
|
||||
values["upload_completed"] = upload_completed
|
||||
if sha256_hash:
|
||||
values["sha256_hash"] = sha256_hash
|
||||
|
||||
query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
|
||||
result = await database.execute(query)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated file record state for {file_name} to {state}")
|
||||
return True
|
||||
|
||||
logger.warning(f"File record not found for update: {file_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file record state: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def list_file_records(
|
||||
user_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
列出文件记录
|
||||
|
||||
Args:
|
||||
user_token: 用户 token(如果提供,只返回该用户的文件)
|
||||
api_key: API Key(如果提供,只返回使用该 key 的文件)
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记(偏移量)
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}")
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
if user_token:
|
||||
query = query.where(FileRecord.user_token == user_token)
|
||||
if api_key:
|
||||
query = query.where(FileRecord.api_key == api_key)
|
||||
|
||||
# 使用偏移量进行分页
|
||||
offset = 0
|
||||
if page_token:
|
||||
try:
|
||||
offset = int(page_token)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid page token: {page_token}")
|
||||
offset = 0
|
||||
|
||||
# 按ID升序排列,使用 OFFSET 和 LIMIT
|
||||
query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
logger.debug(f"Query returned {len(results)} records")
|
||||
if results:
|
||||
logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}")
|
||||
|
||||
# 处理分页
|
||||
has_next = len(results) > page_size
|
||||
if has_next:
|
||||
results = results[:page_size]
|
||||
# 下一页的偏移量是当前偏移量加上本页返回的记录数
|
||||
next_offset = offset + page_size
|
||||
next_page_token = str(next_offset)
|
||||
logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}")
|
||||
else:
|
||||
next_page_token = None
|
||||
logger.debug(f"No next page, returning {len(results)} results")
|
||||
|
||||
return [dict(row) for row in results], next_page_token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def delete_file_record(name: str) -> bool:
|
||||
"""
|
||||
删除文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
query = delete(FileRecord).where(FileRecord.name == name)
|
||||
await database.execute(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file record: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_expired_file_records() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
删除已过期的文件记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 删除的记录列表
|
||||
"""
|
||||
try:
|
||||
# 先获取要删除的记录
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(datetime.timezone.utc)
|
||||
)
|
||||
expired_records = await database.fetch_all(query)
|
||||
|
||||
if not expired_records:
|
||||
return []
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(datetime.timezone.utc)
|
||||
)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Deleted {len(expired_records)} expired file records")
|
||||
return [dict(record) for record in expired_records]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete expired file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_api_key(name: str) -> Optional[str]:
|
||||
"""
|
||||
获取文件对应的 API Key
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: API Key,如果文件不存在或已过期则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord.api_key).where(
|
||||
(FileRecord.name == name) &
|
||||
(FileRecord.expiration_time > datetime.now(timezone.utc))
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
return result["api_key"] if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file API key: {str(e)}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user