mirror of
https://github.com/DullJZ/s3-balance.git
synced 2026-06-30 23:51:26 +08:00
548 lines
16 KiB
Go
548 lines
16 KiB
Go
package storage
|
||
|
||
import (
|
||
"fmt"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Service 存储服务(管理对象元数据)
|
||
type Service struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewService 创建新的存储服务
|
||
func NewService(db *gorm.DB) *Service {
|
||
return &Service{
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
// RecordObject 记录对象信息
|
||
func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[string]string) error {
|
||
// 首先检查是否存在已删除的同名对象
|
||
var deletedObj Object
|
||
if err := s.db.Unscoped().Where("`key` = ?", key).Where("`deleted_at` IS NOT NULL").First(&deletedObj).Error; err == nil {
|
||
// 存在已删除的同名对象,永久删除它
|
||
if err := s.db.Unscoped().Delete(&deletedObj).Error; err != nil {
|
||
return fmt.Errorf("failed to permanently delete soft-deleted object: %w", err)
|
||
}
|
||
}
|
||
|
||
obj := &Object{
|
||
Key: key,
|
||
BucketName: bucketName,
|
||
Size: size,
|
||
}
|
||
|
||
if len(metadata) > 0 {
|
||
obj.Metadata = make(JSON)
|
||
for k, v := range metadata {
|
||
obj.Metadata[k] = v
|
||
}
|
||
} else {
|
||
obj.Metadata = make(JSON)
|
||
}
|
||
|
||
// 使用 Upsert(更新或插入)
|
||
result := s.db.Where("`key` = ?", key).Where("`deleted_at` IS NULL").FirstOrCreate(&obj)
|
||
if result.Error != nil {
|
||
return fmt.Errorf("failed to record object: %w", result.Error)
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
// 对象已存在,更新它
|
||
updates := map[string]interface{}{
|
||
"bucket_name": bucketName,
|
||
"size": size,
|
||
"metadata": obj.Metadata,
|
||
"updated_at": time.Now(),
|
||
}
|
||
if err := s.db.Model(&Object{}).Where("`key` = ?", key).Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update object: %w", err)
|
||
}
|
||
}
|
||
|
||
// 更新存储桶统计
|
||
s.updateBucketStats(bucketName)
|
||
|
||
return nil
|
||
}
|
||
|
||
// FindObjectBucket 查找对象所在的存储桶
|
||
func (s *Service) FindObjectBucket(key string) (string, error) {
|
||
var obj Object
|
||
if err := s.db.Where("`key` = ?", key).First(&obj).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return "", fmt.Errorf("object not found: %s", key)
|
||
}
|
||
return "", fmt.Errorf("failed to find object: %w", err)
|
||
}
|
||
return obj.BucketName, nil
|
||
}
|
||
|
||
// GetObjectInfo 获取对象信息
|
||
func (s *Service) GetObjectInfo(key string) (*Object, error) {
|
||
var obj Object
|
||
if err := s.db.Where("`key` = ?", key).First(&obj).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil, fmt.Errorf("object not found: %s", key)
|
||
}
|
||
return nil, fmt.Errorf("failed to get object info: %w", err)
|
||
}
|
||
return &obj, nil
|
||
}
|
||
|
||
// DeleteObject 删除对象记录(软删除)
|
||
func (s *Service) DeleteObject(key string) error {
|
||
var obj Object
|
||
if err := s.db.Where("`key` = ?", key).First(&obj).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return fmt.Errorf("object not found: %s", key)
|
||
}
|
||
return fmt.Errorf("failed to find object: %w", err)
|
||
}
|
||
|
||
bucketName := obj.BucketName
|
||
|
||
// 软删除
|
||
if err := s.db.Delete(&obj).Error; err != nil {
|
||
return fmt.Errorf("failed to delete object: %w", err)
|
||
}
|
||
|
||
// 更新存储桶统计
|
||
s.updateBucketStats(bucketName)
|
||
|
||
return nil
|
||
}
|
||
|
||
// ListObjects 列出对象(支持S3兼容的参数)
|
||
func (s *Service) ListObjects(bucketName, prefix, marker string, maxKeys int) ([]*Object, error) {
|
||
var objects []*Object
|
||
query := s.db.Model(&Object{})
|
||
|
||
// 按bucket过滤
|
||
if bucketName != "" {
|
||
query = query.Where("bucket_name = ?", bucketName)
|
||
}
|
||
|
||
// 前缀过滤
|
||
if prefix != "" {
|
||
query = query.Where("`key` LIKE ?", prefix+"%")
|
||
}
|
||
|
||
// Marker分页
|
||
if marker != "" {
|
||
query = query.Where("`key` > ?", marker)
|
||
}
|
||
|
||
// 限制返回数量
|
||
if maxKeys > 0 {
|
||
query = query.Limit(maxKeys)
|
||
}
|
||
|
||
// 按key字母顺序排序(S3标准)
|
||
if err := query.Order("`key` ASC").Find(&objects).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to list objects: %w", err)
|
||
}
|
||
|
||
return objects, nil
|
||
}
|
||
|
||
// GetBucketObjects 获取特定存储桶的所有对象
|
||
func (s *Service) GetBucketObjects(bucketName string) ([]*Object, error) {
|
||
var objects []*Object
|
||
if err := s.db.Where("bucket_name = ?", bucketName).Find(&objects).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get bucket objects: %w", err)
|
||
}
|
||
return objects, nil
|
||
}
|
||
|
||
// GetTotalSize 获取所有对象的总大小
|
||
func (s *Service) GetTotalSize() (int64, error) {
|
||
var total int64
|
||
if err := s.db.Model(&Object{}).Select("COALESCE(SUM(size), 0)").Scan(&total).Error; err != nil {
|
||
return 0, fmt.Errorf("failed to get total size: %w", err)
|
||
}
|
||
return total, nil
|
||
}
|
||
|
||
// GetBucketSize 获取特定存储桶的总大小
|
||
func (s *Service) GetBucketSize(bucketName string) (int64, error) {
|
||
var total int64
|
||
if err := s.db.Model(&Object{}).
|
||
Where("bucket_name = ?", bucketName).
|
||
Select("COALESCE(SUM(size), 0)").
|
||
Scan(&total).Error; err != nil {
|
||
return 0, fmt.Errorf("failed to get bucket size: %w", err)
|
||
}
|
||
return total, nil
|
||
}
|
||
|
||
// GetObjectCount 获取对象总数
|
||
func (s *Service) GetObjectCount() (int64, error) {
|
||
var count int64
|
||
if err := s.db.Model(&Object{}).Count(&count).Error; err != nil {
|
||
return 0, fmt.Errorf("failed to get object count: %w", err)
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// GetBucketObjectCount 获取特定存储桶的对象数
|
||
func (s *Service) GetBucketObjectCount(bucketName string) (int64, error) {
|
||
var count int64
|
||
if err := s.db.Model(&Object{}).
|
||
Where("bucket_name = ?", bucketName).
|
||
Count(&count).Error; err != nil {
|
||
return 0, fmt.Errorf("failed to get bucket object count: %w", err)
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// updateBucketStats 更新存储桶统计信息
|
||
func (s *Service) updateBucketStats(bucketName string) error {
|
||
var stats BucketStats
|
||
|
||
// 计算新的统计数据
|
||
var count int64
|
||
var totalSize int64
|
||
|
||
if err := s.db.Model(&Object{}).
|
||
Where("bucket_name = ?", bucketName).
|
||
Count(&count).Error; err != nil {
|
||
return fmt.Errorf("failed to count objects: %w", err)
|
||
}
|
||
|
||
if err := s.db.Model(&Object{}).
|
||
Where("bucket_name = ?", bucketName).
|
||
Select("COALESCE(SUM(size), 0)").
|
||
Scan(&totalSize).Error; err != nil {
|
||
return fmt.Errorf("failed to sum object sizes: %w", err)
|
||
}
|
||
|
||
// 尝试查找现有记录
|
||
err := s.db.Where("bucket_name = ?", bucketName).First(&stats).Error
|
||
|
||
if err == gorm.ErrRecordNotFound {
|
||
// 记录不存在,创建新记录
|
||
stats = BucketStats{
|
||
BucketName: bucketName,
|
||
ObjectCount: count,
|
||
TotalSize: totalSize,
|
||
LastCheckedAt: time.Now(),
|
||
}
|
||
if err := s.db.Create(&stats).Error; err != nil {
|
||
return fmt.Errorf("failed to create bucket stats: %w", err)
|
||
}
|
||
} else if err != nil {
|
||
// 其他数据库错误
|
||
return fmt.Errorf("failed to query bucket stats: %w", err)
|
||
} else {
|
||
// 记录存在,更新统计数据
|
||
updates := map[string]interface{}{
|
||
"object_count": count,
|
||
"total_size": totalSize,
|
||
"last_checked_at": time.Now(),
|
||
}
|
||
|
||
if err := s.db.Model(&stats).Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update bucket stats: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// RecordUploadSession 记录上传会话
|
||
func (s *Service) RecordUploadSession(uploadID, key, bucketName string, size int64) error {
|
||
session := &UploadSession{
|
||
UploadID: uploadID,
|
||
Key: key,
|
||
BucketName: bucketName,
|
||
Size: size,
|
||
Status: "pending",
|
||
}
|
||
|
||
if err := s.db.Create(session).Error; err != nil {
|
||
return fmt.Errorf("failed to record upload session: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetUploadSession 获取上传会话
|
||
func (s *Service) GetUploadSession(uploadID string) (*UploadSession, error) {
|
||
var session UploadSession
|
||
if err := s.db.Where("upload_id = ?", uploadID).First(&session).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil, fmt.Errorf("upload session not found: %s", uploadID)
|
||
}
|
||
return nil, fmt.Errorf("failed to get upload session: %w", err)
|
||
}
|
||
return &session, nil
|
||
}
|
||
|
||
// UpdateUploadSession 更新上传会话
|
||
func (s *Service) UpdateUploadSession(uploadID string, completedParts int, status string) error {
|
||
updates := map[string]interface{}{
|
||
"completed_parts": completedParts,
|
||
"status": status,
|
||
"updated_at": time.Now(),
|
||
}
|
||
|
||
if err := s.db.Model(&UploadSession{}).
|
||
Where("upload_id = ?", uploadID).
|
||
Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update upload session: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// IncrementUploadSessionSize 增加上传会话的大小(用于累加分片大小)
|
||
func (s *Service) IncrementUploadSessionSize(uploadID string, partSize int64) error {
|
||
if err := s.db.Model(&UploadSession{}).
|
||
Where("upload_id = ?", uploadID).
|
||
UpdateColumn("size", gorm.Expr("size + ?", partSize)).Error; err != nil {
|
||
return fmt.Errorf("failed to increment upload session size: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetUploadSessionSize 获取上传会话当前累积的大小
|
||
func (s *Service) GetUploadSessionSize(uploadID string) (int64, error) {
|
||
var session UploadSession
|
||
if err := s.db.Select("size").Where("upload_id = ?", uploadID).First(&session).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return 0, fmt.Errorf("upload session not found: %s", uploadID)
|
||
}
|
||
return 0, fmt.Errorf("failed to get upload session size: %w", err)
|
||
}
|
||
return session.Size, nil
|
||
}
|
||
|
||
// GetPendingUploadSessions 获取正在进行中的上传会话
|
||
func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uploadIdMarker string, maxUploads int) ([]*UploadSession, error) {
|
||
query := s.db.Model(&UploadSession{}).Where("status = ?", "pending")
|
||
|
||
// 根据前缀过滤
|
||
if prefix != "" {
|
||
query = query.Where("`key` LIKE ?", prefix+"%")
|
||
}
|
||
|
||
// 分页标记处理
|
||
if keyMarker != "" {
|
||
if uploadIdMarker != "" {
|
||
// 如果同时指定了key和uploadId标记
|
||
query = query.Where("(`key` > ? OR (`key` = ? AND upload_id > ?))", keyMarker, keyMarker, uploadIdMarker)
|
||
} else {
|
||
query = query.Where("`key` > ?", keyMarker)
|
||
}
|
||
}
|
||
|
||
// 限制返回数量
|
||
if maxUploads > 0 {
|
||
query = query.Limit(maxUploads + 1) // 多查询一个以判断是否截断
|
||
}
|
||
|
||
// 按key和uploadID排序
|
||
query = query.Order("`key` ASC, upload_id ASC")
|
||
|
||
var sessions []*UploadSession
|
||
if err := query.Find(&sessions).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get pending upload sessions: %w", err)
|
||
}
|
||
|
||
return sessions, nil
|
||
}
|
||
|
||
// CleanExpiredSessions 清理过期的上传会话
|
||
func (s *Service) CleanExpiredSessions() error {
|
||
if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending").
|
||
Delete(&UploadSession{}).Error; err != nil {
|
||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// RecordAccessLog 记录访问日志
|
||
func (s *Service) RecordAccessLog(action, key, bucketName, clientIP, userAgent, host string, size int64, success bool, errorMsg string, responseTime int64) error {
|
||
log := &AccessLog{
|
||
Action: action,
|
||
Key: key,
|
||
BucketName: bucketName,
|
||
ClientIP: clientIP,
|
||
UserAgent: userAgent,
|
||
Host: host,
|
||
Size: size,
|
||
Success: success,
|
||
ErrorMsg: errorMsg,
|
||
ResponseTime: responseTime,
|
||
}
|
||
|
||
if err := s.db.Create(log).Error; err != nil {
|
||
return fmt.Errorf("failed to record access log: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetAccessLogs 获取访问日志
|
||
func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) {
|
||
query := s.db.Model(&AccessLog{})
|
||
|
||
if filter != nil {
|
||
if filter.Action != "" {
|
||
query = query.Where("action = ?", filter.Action)
|
||
}
|
||
if filter.Key != "" {
|
||
query = query.Where("`key` = ?", filter.Key)
|
||
}
|
||
if filter.BucketName != "" {
|
||
query = query.Where("bucket_name = ?", filter.BucketName)
|
||
}
|
||
if filter.ClientIP != "" {
|
||
query = query.Where("client_ip = ?", filter.ClientIP)
|
||
}
|
||
if filter.Success != nil {
|
||
query = query.Where("success = ?", *filter.Success)
|
||
}
|
||
if !filter.StartTime.IsZero() {
|
||
query = query.Where("created_at >= ?", filter.StartTime)
|
||
}
|
||
if !filter.EndTime.IsZero() {
|
||
query = query.Where("created_at <= ?", filter.EndTime)
|
||
}
|
||
if filter.Limit > 0 {
|
||
query = query.Limit(filter.Limit)
|
||
}
|
||
if filter.Offset > 0 {
|
||
query = query.Offset(filter.Offset)
|
||
}
|
||
}
|
||
|
||
var logs []*AccessLog
|
||
if err := query.Order("created_at DESC").Find(&logs).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get access logs: %w", err)
|
||
}
|
||
|
||
return logs, nil
|
||
}
|
||
|
||
// CreateVirtualBucketMapping 创建虚拟存储桶文件级映射
|
||
func (s *Service) CreateVirtualBucketMapping(virtualBucketName, objectKey, realBucketName string) error {
|
||
mapping := &VirtualBucketMapping{
|
||
VirtualBucketName: virtualBucketName,
|
||
ObjectKey: objectKey,
|
||
RealBucketName: realBucketName,
|
||
}
|
||
|
||
if err := s.db.Create(mapping).Error; err != nil {
|
||
return fmt.Errorf("failed to create virtual bucket mapping: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetVirtualBucketMapping 获取虚拟存储桶文件级映射
|
||
func (s *Service) GetVirtualBucketMapping(virtualBucketName, objectKey string) (*VirtualBucketMapping, error) {
|
||
var mapping VirtualBucketMapping
|
||
if err := s.db.Where("virtual_bucket_name = ? AND object_key = ?", virtualBucketName, objectKey).First(&mapping).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil, fmt.Errorf("virtual bucket mapping not found: %s/%s", virtualBucketName, objectKey)
|
||
}
|
||
return nil, fmt.Errorf("failed to get virtual bucket mapping: %w", err)
|
||
}
|
||
return &mapping, nil
|
||
}
|
||
|
||
// GetVirtualBucketMappings 获取所有虚拟存储桶映射
|
||
func (s *Service) GetVirtualBucketMappings() ([]*VirtualBucketMapping, error) {
|
||
var mappings []*VirtualBucketMapping
|
||
if err := s.db.Find(&mappings).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get virtual bucket mappings: %w", err)
|
||
}
|
||
return mappings, nil
|
||
}
|
||
|
||
// GetVirtualBucketMappingsForBucket 获取指定虚拟存储桶的所有映射
|
||
func (s *Service) GetVirtualBucketMappingsForBucket(virtualBucketName string) ([]*VirtualBucketMapping, error) {
|
||
var mappings []*VirtualBucketMapping
|
||
if err := s.db.Where("virtual_bucket_name = ?", virtualBucketName).Find(&mappings).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get virtual bucket mappings for bucket %s: %w", virtualBucketName, err)
|
||
}
|
||
return mappings, nil
|
||
}
|
||
|
||
// UpdateVirtualBucketMapping 更新虚拟存储桶映射
|
||
func (s *Service) UpdateVirtualBucketMapping(virtualBucketName, objectKey, realBucketName string) error {
|
||
updates := map[string]interface{}{
|
||
"real_bucket_name": realBucketName,
|
||
"updated_at": time.Now(),
|
||
}
|
||
|
||
if err := s.db.Model(&VirtualBucketMapping{}).
|
||
Where("virtual_bucket_name = ? AND object_key = ?", virtualBucketName, objectKey).
|
||
Updates(updates).Error; err != nil {
|
||
return fmt.Errorf("failed to update virtual bucket mapping: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DeleteVirtualBucketMapping 删除虚拟存储桶映射
|
||
func (s *Service) DeleteVirtualBucketMapping(virtualBucketName string) error {
|
||
// 删除虚拟存储桶的所有映射
|
||
if err := s.db.Where("virtual_bucket_name = ?", virtualBucketName).
|
||
Delete(&VirtualBucketMapping{}).Error; err != nil {
|
||
return fmt.Errorf("failed to delete virtual bucket mapping: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DeleteVirtualBucketObjectMapping 删除虚拟存储桶中特定对象的映射
|
||
func (s *Service) DeleteVirtualBucketObjectMapping(virtualBucketName, objectKey string) error {
|
||
if err := s.db.Where("virtual_bucket_name = ? AND object_key = ?", virtualBucketName, objectKey).
|
||
Delete(&VirtualBucketMapping{}).Error; err != nil {
|
||
return fmt.Errorf("failed to delete virtual bucket object mapping: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetVirtualBucketObjects 获取虚拟存储桶中的所有对象
|
||
func (s *Service) GetVirtualBucketObjects(virtualBucketName string) ([]*Object, error) {
|
||
// 获取虚拟存储桶的所有文件映射
|
||
mappings, err := s.GetVirtualBucketMappingsForBucket(virtualBucketName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if len(mappings) == 0 {
|
||
return []*Object{}, nil
|
||
}
|
||
|
||
// 收集所有对象键
|
||
objectKeys := make([]string, 0, len(mappings))
|
||
for _, mapping := range mappings {
|
||
objectKeys = append(objectKeys, mapping.ObjectKey)
|
||
}
|
||
|
||
// 从对象表中查询这些对象
|
||
var objects []*Object
|
||
if err := s.db.Where("`key` IN ?", objectKeys).Find(&objects).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get objects for virtual bucket: %w", err)
|
||
}
|
||
|
||
return objects, nil
|
||
}
|
||
|
||
// DeleteVirtualBucketFileMapping 删除虚拟存储桶文件映射
|
||
func (s *Service) DeleteVirtualBucketFileMapping(virtualBucketName, objectKey string) error {
|
||
if err := s.db.Where("virtual_bucket_name = ? AND object_key = ?", virtualBucketName, objectKey).
|
||
Delete(&VirtualBucketMapping{}).Error; err != nil {
|
||
return fmt.Errorf("failed to delete virtual bucket file mapping: %w", err)
|
||
}
|
||
return nil
|
||
}
|