diff --git a/cmd/s3-balance/main.go b/cmd/s3-balance/main.go index 5d811e6..7839bfc 100644 --- a/cmd/s3-balance/main.go +++ b/cmd/s3-balance/main.go @@ -45,11 +45,14 @@ func main() { } defer database.Close() + // 创建存储服务 + storageService := storage.NewService(database.GetDB()) + // 创建指标服务 metricsService := metrics.New() // 创建存储桶管理器 - bucketManager, err := bucket.NewManager(cfg, metricsService) + bucketManager, err := bucket.NewManager(cfg, metricsService, storageService) if err != nil { log.Fatalf("Failed to create bucket manager: %v", err) } @@ -74,9 +77,6 @@ func main() { 60*time.Minute, // 下载URL有效期 ) - // 创建存储服务 - storageService := storage.NewService(database.GetDB()) - // 启动定期清理过期上传会话的任务 startSessionCleaner(ctx, storageService) diff --git a/internal/api/operation_tracker.go b/internal/api/operation_tracker.go index e894469..822a4e5 100644 --- a/internal/api/operation_tracker.go +++ b/internal/api/operation_tracker.go @@ -15,7 +15,22 @@ func (h *S3Handler) recordBackendOperation(b *bucket.BucketInfo, category bucket if h.metrics != nil { h.metrics.RecordBackendOperation(b.Config.Name, string(category)) } - if disabled := b.RecordOperation(category); disabled { + + var disabled bool + + if h.storage != nil { + newCount, err := h.storage.IncrementBucketOperation(b.Config.Name, string(category)) + if err != nil { + log.Printf("failed to persist backend operation count for bucket %s: %v", b.Config.Name, err) + disabled = b.RecordOperation(category) + } else { + disabled = b.SetOperationCount(category, newCount) + } + } else { + disabled = b.RecordOperation(category) + } + + if disabled { log.Printf("Bucket %s disabled after exceeding %s-type operation limit", b.Config.Name, category) } } diff --git a/internal/bucket/manager.go b/internal/bucket/manager.go index ca3b6ff..41c86ee 100644 --- a/internal/bucket/manager.go +++ b/internal/bucket/manager.go @@ -10,6 +10,7 @@ import ( "github.com/DullJZ/s3-balance/internal/config" "github.com/DullJZ/s3-balance/internal/health" "github.com/DullJZ/s3-balance/internal/metrics" + "github.com/DullJZ/s3-balance/internal/storage" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" @@ -49,15 +50,17 @@ type Manager struct { healthMonitor *health.Monitor statsMonitor *health.StatsMonitor monitorCtx context.Context + storage *storage.Service } // NewManager 创建新的存储桶管理器 -func NewManager(cfg *config.Config, metrics *metrics.Metrics) (*Manager, error) { +func NewManager(cfg *config.Config, metrics *metrics.Metrics, storageService *storage.Service) (*Manager, error) { m := &Manager{ buckets: make(map[string]*BucketInfo), config: cfg, stopChan: make(chan struct{}), metrics: metrics, + storage: storageService, } // 初始化所有存储桶客户端 @@ -84,9 +87,49 @@ func NewManager(cfg *config.Config, metrics *metrics.Metrics) (*Manager, error) // 初始化健康监控 m.initHealthMonitoring() + // 加载持久化的操作计数 + m.loadOperationCounts() + return m, nil } +func (m *Manager) loadOperationCounts() { + if m.storage == nil { + return + } + + counts, err := m.storage.GetBucketOperationCounts() + if err != nil { + log.Printf("Failed to load bucket operation counts: %v", err) + return + } + + m.mu.RLock() + buckets := make(map[string]*BucketInfo, len(m.buckets)) + for name, info := range m.buckets { + buckets[name] = info + } + m.mu.RUnlock() + + for name, info := range buckets { + if info == nil { + continue + } + + oc, ok := counts[name] + if !ok { + continue + } + + if info.SetOperationCount(OperationTypeA, oc.CountA) { + log.Printf("Bucket %s disabled after exceeding A-type operation limit (persisted)", name) + } + if info.SetOperationCount(OperationTypeB, oc.CountB) { + log.Printf("Bucket %s disabled after exceeding B-type operation limit (persisted)", name) + } + } +} + // createS3Client 创建S3客户端 func createS3Client(bucketCfg config.BucketConfig) (*s3.Client, error) { // 创建自定义端点解析器 @@ -324,6 +367,45 @@ func (b *BucketInfo) RecordOperation(category OperationCategory) bool { return false } +// SetOperationCount 设置指定类别的操作计数并检查上限 +func (b *BucketInfo) SetOperationCount(category OperationCategory, value int64) bool { + if b == nil { + return false + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.Config.Virtual { + return false + } + + var limit int64 + + switch category { + case OperationTypeA: + b.operationCountA = value + limit = int64(b.Config.OperationLimits.TypeA) + case OperationTypeB: + b.operationCountB = value + limit = int64(b.Config.OperationLimits.TypeB) + default: + return false + } + + if limit <= 0 { + return false + } + + if !b.operationLimitReached && value >= limit { + b.Available = false + b.operationLimitReached = true + return true + } + + return false +} + // IsVirtual 检查是否为虚拟存储桶 func (b *BucketInfo) IsVirtual() bool { b.mu.RLock() @@ -428,6 +510,8 @@ func (m *Manager) UpdateConfig(newConfig *config.Config) error { m.mu.Unlock() + m.loadOperationCounts() + if restartMonitors { m.startMonitors() } diff --git a/internal/storage/models.go b/internal/storage/models.go index 08a114f..a88837f 100644 --- a/internal/storage/models.go +++ b/internal/storage/models.go @@ -29,13 +29,15 @@ func (Object) TableName() string { // BucketStats 存储桶统计信息模型 type BucketStats struct { - ID uint `gorm:"primaryKey" json:"id"` - BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"` - ObjectCount int64 `gorm:"not null;default:0" json:"object_count"` - TotalSize int64 `gorm:"not null;default:0" json:"total_size"` - LastCheckedAt time.Time `json:"last_checked_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `gorm:"primaryKey" json:"id"` + BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"` + ObjectCount int64 `gorm:"not null;default:0" json:"object_count"` + TotalSize int64 `gorm:"not null;default:0" json:"total_size"` + OperationCountA int64 `gorm:"not null;default:0" json:"operation_count_a"` + OperationCountB int64 `gorm:"not null;default:0" json:"operation_count_b"` + LastCheckedAt time.Time `json:"last_checked_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // TableName 指定表名 diff --git a/internal/storage/service.go b/internal/storage/service.go index 20ad733..9ab629e 100644 --- a/internal/storage/service.go +++ b/internal/storage/service.go @@ -1,6 +1,7 @@ package storage import ( + "errors" "fmt" "time" @@ -12,6 +13,12 @@ type Service struct { db *gorm.DB } +// OperationCounts 后端操作计数 +type OperationCounts struct { + CountA int64 + CountB int64 +} + // NewService 创建新的存储服务 func NewService(db *gorm.DB) *Service { return &Service{ @@ -254,6 +261,84 @@ func (s *Service) updateBucketStats(bucketName string) error { return nil } +func (s *Service) ensureBucketStats(bucketName string) (*BucketStats, error) { + if bucketName == "" { + return nil, fmt.Errorf("bucket name cannot be empty") + } + + stats := &BucketStats{} + err := s.db.Where("bucket_name = ?", bucketName).First(stats).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + stats = &BucketStats{BucketName: bucketName, LastCheckedAt: time.Now()} + if createErr := s.db.Create(stats).Error; createErr != nil { + // 如果在并发创建下出现重复键,忽略并再次查询 + if errors.Is(createErr, gorm.ErrDuplicatedKey) { + if retryErr := s.db.Where("bucket_name = ?", bucketName).First(stats).Error; retryErr != nil { + return nil, fmt.Errorf("failed to fetch bucket stats after duplicate: %w", retryErr) + } + } else { + return nil, fmt.Errorf("failed to create bucket stats: %w", createErr) + } + } + return stats, nil + } else if err != nil { + return nil, fmt.Errorf("failed to fetch bucket stats: %w", err) + } + + return stats, nil +} + +// IncrementBucketOperation 增加指定存储桶的操作计数 +func (s *Service) IncrementBucketOperation(bucketName, category string) (int64, error) { + if _, err := s.ensureBucketStats(bucketName); err != nil { + return 0, err + } + + var field string + switch category { + case "A": + field = "operation_count_a" + case "B": + field = "operation_count_b" + default: + return 0, fmt.Errorf("unknown operation category: %s", category) + } + + if err := s.db.Model(&BucketStats{}). + Where("bucket_name = ?", bucketName). + UpdateColumn(field, gorm.Expr(field+" + ?", 1)).Error; err != nil { + return 0, fmt.Errorf("failed to increment %s for bucket %s: %w", field, bucketName, err) + } + + var count int64 + if err := s.db.Model(&BucketStats{}). + Where("bucket_name = ?", bucketName). + Select(field). + Scan(&count).Error; err != nil { + return 0, fmt.Errorf("failed to fetch updated %s for bucket %s: %w", field, bucketName, err) + } + + return count, nil +} + +// GetBucketOperationCounts 获取所有存储桶的操作计数 +func (s *Service) GetBucketOperationCounts() (map[string]OperationCounts, error) { + var stats []BucketStats + if err := s.db.Find(&stats).Error; err != nil { + return nil, fmt.Errorf("failed to list bucket stats: %w", err) + } + + result := make(map[string]OperationCounts, len(stats)) + for _, st := range stats { + result[st.BucketName] = OperationCounts{ + CountA: st.OperationCountA, + CountB: st.OperationCountB, + } + } + + return result, nil +} + // RecordUploadSession 记录上传会话 func (s *Service) RecordUploadSession(uploadID, key, bucketName string, size int64) error { session := &UploadSession{