commit 37b6adb6dee2dffa1e4b125ac8b55cd3ca0cb6c1 Author: DullJZ <79080562+DullJZ@users.noreply.github.com> Date: Fri Aug 22 21:15:56 2025 +0800 first diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a35a948 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Binaries +*.exe +*.dll +*.so +*.dylib + +# Test binary +*.test + +# Output of the go coverage tool +*.out + +# Dependency directories +vendor/ + +# Configuration files (keep examples) +config/*.yaml +!config/*.example.yaml + +# Database files +data/ +*.db +*.db-shm +*.db-wal + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log +logs/ + +# Environment variables +.env +.env.local diff --git a/cmd/s3-balance/main.go b/cmd/s3-balance/main.go new file mode 100644 index 0000000..e888730 --- /dev/null +++ b/cmd/s3-balance/main.go @@ -0,0 +1,182 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/DullJZ/s3-balance/internal/api" + "github.com/DullJZ/s3-balance/internal/balancer" + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/DullJZ/s3-balance/internal/config" + "github.com/DullJZ/s3-balance/internal/database" + "github.com/DullJZ/s3-balance/internal/storage" + "github.com/DullJZ/s3-balance/pkg/presigner" + "github.com/gorilla/mux" +) + +func main() { + // 解析命令行参数 + var configFile string + flag.StringVar(&configFile, "config", "config/config.yaml", "Path to configuration file") + flag.Parse() + + // 加载配置 + cfg, err := config.Load(configFile) + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // 初始化数据库 + if err := database.Initialize(&cfg.Database); err != nil { + log.Fatalf("Failed to initialize database: %v", err) + } + defer database.Close() + + // 创建存储桶管理器 + bucketManager, err := bucket.NewManager(cfg) + if err != nil { + log.Fatalf("Failed to create bucket manager: %v", err) + } + + // 启动存储桶管理器(健康检查和统计更新) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + bucketManager.Start(ctx) + + // 创建负载均衡器 + lb, err := balancer.NewBalancer(bucketManager, &cfg.Balancer) + if err != nil { + log.Fatalf("Failed to create balancer: %v", err) + } + + // 创建预签名URL生成器 + signer := presigner.NewPresigner( + 15*time.Minute, // 上传URL有效期 + 60*time.Minute, // 下载URL有效期 + ) + + // 创建存储服务 + storageService := storage.NewService(database.GetDB()) + + // 创建S3兼容API处理器 + s3Handler := api.NewS3Handler( + bucketManager, + lb, + signer, + storageService, + cfg.S3API.AccessKey, + cfg.S3API.SecretKey, + ) + + // 设置路由 + router := mux.NewRouter() + + // 运行在S3兼容模式 + log.Println("Running in S3-compatible mode") + s3Handler.RegisterS3Routes(router) + + // 添加CORS中间件 + router.Use(corsMiddleware) + + // 添加日志中间件 + router.Use(loggingMiddleware) + + // 创建HTTP服务器 + srv := &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), + Handler: router, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + IdleTimeout: cfg.Server.IdleTimeout, + } + + // 启动服务器 + go func() { + log.Printf("Starting S3 Balance Service on %s", srv.Addr) + log.Printf("Load balancing strategy: %s", cfg.Balancer.Strategy) + log.Printf("Managed buckets: %d", len(cfg.Buckets)) + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start server: %v", err) + } + }() + + // 等待中断信号 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + <-sigChan + + // 优雅关闭 + log.Println("Shutting down server...") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownCancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Printf("Server shutdown error: %v", err) + } + + // 停止存储桶管理器 + bucketManager.Stop() + cancel() + + log.Println("Server stopped") +} + +// CORS中间件 +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// 日志中间件 +func loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // 包装ResponseWriter以捕获状态码 + wrapped := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } + + next.ServeHTTP(wrapped, r) + + log.Printf( + "[%s] %s %s %d %v", + r.RemoteAddr, + r.Method, + r.RequestURI, + wrapped.statusCode, + time.Since(start), + ) + }) +} + +// responseWriter 包装器用于捕获状态码 +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} diff --git a/config/config.example.yaml b/config/config.example.yaml new file mode 100644 index 0000000..b1c5484 --- /dev/null +++ b/config/config.example.yaml @@ -0,0 +1,129 @@ +# S3 Balance Service Configuration Example + +# 服务器配置 +server: + host: "0.0.0.0" + port: 8080 + read_timeout: 30s + write_timeout: 30s + idle_timeout: 60s + +# 数据库配置 +database: + # 数据库类型: sqlite, mysql, postgres + type: "sqlite" + + # 数据源名称 (DSN) + # SQLite 示例: + dsn: "data/s3-balance.db" + + # MySQL 示例: + # dsn: "user:password@tcp(localhost:3306)/s3balance?charset=utf8mb4&parseTime=True&loc=Local" + + # PostgreSQL 示例: + # dsn: "host=localhost user=postgres password=password dbname=s3balance port=5432 sslmode=disable TimeZone=Asia/Shanghai" + + # 连接池配置 + max_open_conns: 25 # 最大打开连接数 + max_idle_conns: 5 # 最大空闲连接数 + conn_max_lifetime: 300 # 连接最大生命周期(秒) + + # 日志级别: silent, error, warn, info + log_level: "warn" + + # 是否自动迁移数据库表 + auto_migrate: true + +# S3存储桶配置 +buckets: + # 第一个存储桶 - AWS S3 + - name: "my-bucket-1" + endpoint: "" # 留空使用默认AWS端点 + region: "us-east-1" + access_key_id: "YOUR_AWS_ACCESS_KEY_ID" + secret_access_key: "YOUR_AWS_SECRET_ACCESS_KEY" + max_size: "10GB" # 最大容量限制 + weight: 10 # 权重(用于加权负载均衡) + enabled: true + use_ssl: true + path_style: false # AWS S3使用虚拟主机风格 + + # 第二个存储桶 - MinIO + - name: "my-bucket-2" + endpoint: "http://localhost:9000" + region: "us-east-1" + access_key_id: "minioadmin" + secret_access_key: "minioadmin" + max_size: "5GB" + weight: 5 + enabled: true + use_ssl: false + path_style: true # MinIO通常使用路径风格 + + # 第三个存储桶 - 阿里云OSS(兼容S3) + - name: "my-bucket-3" + endpoint: "https://oss-cn-hangzhou.aliyuncs.com" + region: "cn-hangzhou" + access_key_id: "YOUR_ALIYUN_ACCESS_KEY_ID" + secret_access_key: "YOUR_ALIYUN_SECRET_ACCESS_KEY" + max_size: "20GB" + weight: 15 + enabled: true + use_ssl: true + path_style: false + + # 备用存储桶(可以禁用) + - name: "backup-bucket" + endpoint: "http://backup-s3.example.com" + region: "us-west-2" + access_key_id: "BACKUP_ACCESS_KEY" + secret_access_key: "BACKUP_SECRET_KEY" + max_size: "100GB" + weight: 1 + enabled: false # 当前禁用 + use_ssl: false + path_style: true + +# 负载均衡配置 +balancer: + # 负载均衡策略,可选值: + # - "round-robin": 轮询 + # - "least-space": 选择剩余空间最多的存储桶 + # - "weighted": 基于权重的随机选择 + # - "consistent-hash": 一致性哈希(相同的key总是选择相同的存储桶) + strategy: "least-space" + + # 健康检查周期 + health_check_period: 30s + + # 统计信息更新周期 + update_stats_period: 60s + + # 重试配置 + retry_attempts: 3 + retry_delay: 1s + +# 监控指标配置 +metrics: + enabled: true + path: "/metrics" + port: 9090 + +# S3兼容API配置 +s3api: + # 客户端连接用的Access Key + access_key: "AKIAIOSFODNN7EXAMPLE" + + # 客户端连接用的Secret Key + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + + # 是否使用虚拟主机模式 + virtual_host: false + + # 工作模式: + # false (默认):预签名重定向模式,客户端直接与后端存储交互 + # true:代理模式,数据通过S3 Balance服务器传输 + proxy_mode: false + + # 是否需要认证(开发环境可设为false) + auth_required: false diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..234d852 --- /dev/null +++ b/go.mod @@ -0,0 +1,45 @@ +module github.com/DullJZ/s3-balance + +go 1.24.5 + +require ( + github.com/aws/aws-sdk-go-v2 v1.38.0 + github.com/aws/aws-sdk-go-v2/config v1.31.1 + github.com/aws/aws-sdk-go-v2/credentials v1.18.5 + github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0 + github.com/gorilla/mux v1.8.1 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.28.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.37.1 // indirect + github.com/aws/smithy-go v1.22.5 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/text v0.21.0 // indirect + gorm.io/driver/mysql v1.6.0 // indirect + gorm.io/driver/postgres v1.6.0 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect + gorm.io/gorm v1.30.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9d986a4 --- /dev/null +++ b/go.sum @@ -0,0 +1,80 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= +github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg= +github.com/aws/aws-sdk-go-v2/config v1.31.1 h1:PSQn4ObaQLaHl6qjs+XYH2pkxyHzZlk1GgQDrKlRJ7I= +github.com/aws/aws-sdk-go-v2/config v1.31.1/go.mod h1:3UA8Gj+2nzpV8WBUF0b19onBfz0YMXDQyGEW0Ru1ntI= +github.com/aws/aws-sdk-go-v2/credentials v1.18.5 h1:DATc1xnpHUV8VgvtnVQul+zuCwK6vz7gtkbKEUZcuNI= +github.com/aws/aws-sdk-go-v2/credentials v1.18.5/go.mod h1:y7aigZzjm1jUZuCgOrlBng+VJrKkknY2Cl0JWxG7vHU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3 h1:ZV2XK2L3HBq9sCKQiQ/MdhZJppH/rH0vddEAamsHUIs= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3/go.mod h1:b9F9tk2HdHpbf3xbN7rUZcfmJI26N6NcJu/8OsBFI/0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3 h1:3ZKmesYBaFX33czDl6mbrcHb6jeheg6LqjJhQdefhsY= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3/go.mod h1:7ryVb78GLCnjq7cw45N6oUb9REl7/vNUwjvIqC5UgdY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3 h1:SE/e52dq9a05RuxzLcjT+S5ZpQobj3ie3UTaSf2NnZc= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3/go.mod h1:zkpvBTsR020VVr8TOrwK2TrUW9pOir28sH5ECHpnAfo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0 h1:egoDf+Geuuntmw79Mz6mk9gGmELCPzg5PFEABOHB+6Y= +github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0/go.mod h1:t9MDi29H+HDbkolTSQtbI0HP9DemAWQzUjmWC7LGMnE= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.1 h1:YfsU8hHGvVT+c6Q8MUs8haDbFQajAImrB7yZ9XnPcBY= +github.com/aws/aws-sdk-go-v2/service/sso v1.28.1/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1 h1:b4REsk5C0hooowAPmV8fS2haHb+HCyb5FKSKOZRBBfU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.1 h1:ssCHKyNJqTnqRH4Vlf+jI0brtGQYBvzWwnATsOMk1mk= +github.com/aws/aws-sdk-go-v2/service/sts v1.37.1/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/internal/api/handler.go b/internal/api/handler.go new file mode 100644 index 0000000..c1448e6 --- /dev/null +++ b/internal/api/handler.go @@ -0,0 +1,392 @@ +package api + +import ( + "context" + "encoding/json" + "log" + "net/http" + "strconv" + "time" + + "github.com/DullJZ/s3-balance/internal/balancer" + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/DullJZ/s3-balance/internal/storage" + "github.com/DullJZ/s3-balance/pkg/presigner" + "github.com/gorilla/mux" +) + +// Handler API处理器 +type Handler struct { + bucketManager *bucket.Manager + balancer *balancer.Balancer + presigner *presigner.Presigner + storage *storage.Service +} + +// NewHandler 创建新的API处理器 +func NewHandler( + bucketManager *bucket.Manager, + balancer *balancer.Balancer, + presigner *presigner.Presigner, + storage *storage.Service, +) *Handler { + return &Handler{ + bucketManager: bucketManager, + balancer: balancer, + presigner: presigner, + storage: storage, + } +} + +// RegisterRoutes 注册路由 +func (h *Handler) RegisterRoutes(router *mux.Router) { + // 健康检查 + router.HandleFunc("/health", h.handleHealth).Methods("GET") + + // 存储桶状态 + router.HandleFunc("/api/v1/buckets", h.handleListBuckets).Methods("GET") + router.HandleFunc("/api/v1/buckets/{bucket}/stats", h.handleBucketStats).Methods("GET") + + // 预签名URL生成 + router.HandleFunc("/api/v1/presign/upload", h.handlePresignUpload).Methods("POST") + router.HandleFunc("/api/v1/presign/download", h.handlePresignDownload).Methods("POST") + router.HandleFunc("/api/v1/presign/delete", h.handlePresignDelete).Methods("POST") + router.HandleFunc("/api/v1/presign/multipart", h.handlePresignMultipart).Methods("POST") + + // 对象操作(记录元数据) + router.HandleFunc("/api/v1/objects", h.handleListObjects).Methods("GET") + router.HandleFunc("/api/v1/objects/{key:.*}", h.handleGetObjectInfo).Methods("GET") + router.HandleFunc("/api/v1/objects/{key:.*}", h.handleDeleteObject).Methods("DELETE") +} + +// 健康检查 +func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "status": "healthy", + "time": time.Now().Unix(), + } + h.sendJSON(w, http.StatusOK, response) +} + +// 列出所有存储桶状态 +func (h *Handler) handleListBuckets(w http.ResponseWriter, r *http.Request) { + buckets := h.bucketManager.GetAllBuckets() + + var bucketList []map[string]interface{} + for _, b := range buckets { + bucketList = append(bucketList, map[string]interface{}{ + "name": b.Config.Name, + "endpoint": b.Config.Endpoint, + "region": b.Config.Region, + "max_size": b.Config.MaxSize, + "max_size_bytes": b.Config.MaxSizeBytes, + "used_size": b.GetUsedSize(), + "available": b.IsAvailable(), + "weight": b.Config.Weight, + "enabled": b.Config.Enabled, + }) + } + + h.sendJSON(w, http.StatusOK, map[string]interface{}{ + "buckets": bucketList, + "strategy": h.balancer.GetStrategy(), + }) +} + +// 获取单个存储桶统计 +func (h *Handler) handleBucketStats(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + + bucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { + h.sendError(w, http.StatusNotFound, "bucket not found") + return + } + + stats := map[string]interface{}{ + "name": bucket.Config.Name, + "max_size_bytes": bucket.Config.MaxSizeBytes, + "used_size": bucket.GetUsedSize(), + "available_space": bucket.GetAvailableSpace(), + "available": bucket.IsAvailable(), + "last_checked": bucket.LastChecked, + } + + h.sendJSON(w, http.StatusOK, stats) +} + +// PresignUploadRequest 上传预签名请求 +type PresignUploadRequest struct { + Key string `json:"key"` + Size int64 `json:"size"` + ContentType string `json:"content_type,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// 生成上传预签名URL +func (h *Handler) handlePresignUpload(w http.ResponseWriter, r *http.Request) { + var req PresignUploadRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.sendError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Key == "" { + h.sendError(w, http.StatusBadRequest, "key is required") + return + } + + // 选择存储桶 + bucket, err := h.balancer.SelectBucket(req.Key, req.Size) + if err != nil { + h.sendError(w, http.StatusServiceUnavailable, err.Error()) + return + } + + // 生成预签名URL + uploadURL, err := h.presigner.GenerateUploadURL( + context.Background(), + bucket, + req.Key, + req.ContentType, + req.Metadata, + ) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to generate upload URL") + return + } + + // 记录对象元数据 + if err := h.storage.RecordObject(req.Key, bucket.Config.Name, req.Size, req.Metadata); err != nil { + log.Printf("Failed to record object metadata: %v", err) + } + + // 更新存储桶使用量(预估) + bucket.UpdateUsedSize(req.Size) + + h.sendJSON(w, http.StatusOK, uploadURL) +} + +// PresignDownloadRequest 下载预签名请求 +type PresignDownloadRequest struct { + Key string `json:"key"` +} + +// 生成下载预签名URL +func (h *Handler) handlePresignDownload(w http.ResponseWriter, r *http.Request) { + var req PresignDownloadRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.sendError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Key == "" { + h.sendError(w, http.StatusBadRequest, "key is required") + return + } + + // 查找对象所在的存储桶 + bucketName, err := h.storage.FindObjectBucket(req.Key) + if err != nil { + h.sendError(w, http.StatusNotFound, "object not found") + return + } + + bucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { + h.sendError(w, http.StatusNotFound, "bucket not found") + return + } + + // 生成预签名URL + downloadURL, err := h.presigner.GenerateDownloadURL( + context.Background(), + bucket, + req.Key, + ) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to generate download URL") + return + } + + h.sendJSON(w, http.StatusOK, downloadURL) +} + +// PresignDeleteRequest 删除预签名请求 +type PresignDeleteRequest struct { + Key string `json:"key"` +} + +// 生成删除预签名URL +func (h *Handler) handlePresignDelete(w http.ResponseWriter, r *http.Request) { + var req PresignDeleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.sendError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Key == "" { + h.sendError(w, http.StatusBadRequest, "key is required") + return + } + + // 查找对象所在的存储桶 + bucketName, err := h.storage.FindObjectBucket(req.Key) + if err != nil { + h.sendError(w, http.StatusNotFound, "object not found") + return + } + + bucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { + h.sendError(w, http.StatusNotFound, "bucket not found") + return + } + + // 生成预签名URL + deleteURL, err := h.presigner.GenerateDeleteURL( + context.Background(), + bucket, + req.Key, + ) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to generate delete URL") + return + } + + h.sendJSON(w, http.StatusOK, deleteURL) +} + +// PresignMultipartRequest 分片上传预签名请求 +type PresignMultipartRequest struct { + Key string `json:"key"` + PartCount int `json:"part_count"` + Size int64 `json:"size"` +} + +// 生成分片上传预签名URLs +func (h *Handler) handlePresignMultipart(w http.ResponseWriter, r *http.Request) { + var req PresignMultipartRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.sendError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Key == "" || req.PartCount <= 0 { + h.sendError(w, http.StatusBadRequest, "invalid parameters") + return + } + + // 选择存储桶 + bucket, err := h.balancer.SelectBucket(req.Key, req.Size) + if err != nil { + h.sendError(w, http.StatusServiceUnavailable, err.Error()) + return + } + + // 生成预签名URLs + multipartURLs, err := h.presigner.GenerateMultipartUploadURLs( + context.Background(), + bucket, + req.Key, + req.PartCount, + ) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to generate multipart URLs") + return + } + + // 记录对象元数据 + if err := h.storage.RecordObject(req.Key, bucket.Config.Name, req.Size, nil); err != nil { + log.Printf("Failed to record object metadata: %v", err) + } + + // 更新存储桶使用量(预估) + bucket.UpdateUsedSize(req.Size) + + h.sendJSON(w, http.StatusOK, multipartURLs) +} + +// 列出对象 +func (h *Handler) handleListObjects(w http.ResponseWriter, r *http.Request) { + prefix := r.URL.Query().Get("prefix") + bucketName := r.URL.Query().Get("bucket") + marker := r.URL.Query().Get("marker") + limitStr := r.URL.Query().Get("limit") + + limit := 100 + if limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + // 调用更新后的ListObjects方法,传入所有必需的参数 + objects, err := h.storage.ListObjects(bucketName, prefix, marker, limit) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to list objects") + return + } + + h.sendJSON(w, http.StatusOK, map[string]interface{}{ + "objects": objects, + "count": len(objects), + }) +} + +// 获取对象信息 +func (h *Handler) handleGetObjectInfo(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + key := vars["key"] + + info, err := h.storage.GetObjectInfo(key) + if err != nil { + h.sendError(w, http.StatusNotFound, "object not found") + return + } + + h.sendJSON(w, http.StatusOK, info) +} + +// 删除对象(只删除元数据记录) +func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + key := vars["key"] + + // 获取对象信息以更新存储桶使用量 + info, err := h.storage.GetObjectInfo(key) + if err != nil { + h.sendError(w, http.StatusNotFound, "object not found") + return + } + + // 更新存储桶使用量 + if bucket, ok := h.bucketManager.GetBucket(info.BucketName); ok { + bucket.UpdateUsedSize(-info.Size) + } + + // 删除元数据记录 + if err := h.storage.DeleteObject(key); err != nil { + h.sendError(w, http.StatusInternalServerError, "failed to delete object") + return + } + + h.sendJSON(w, http.StatusOK, map[string]string{ + "message": "object deleted successfully", + }) +} + +// 发送JSON响应 +func (h *Handler) sendJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +// 发送错误响应 +func (h *Handler) sendError(w http.ResponseWriter, status int, message string) { + h.sendJSON(w, status, map[string]string{ + "error": message, + }) +} diff --git a/internal/api/s3_handler.go b/internal/api/s3_handler.go new file mode 100644 index 0000000..c50ec49 --- /dev/null +++ b/internal/api/s3_handler.go @@ -0,0 +1,780 @@ +package api + +import ( + "context" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/DullJZ/s3-balance/internal/balancer" + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/DullJZ/s3-balance/internal/storage" + "github.com/DullJZ/s3-balance/pkg/presigner" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/gorilla/mux" +) + +// S3Handler S3兼容的API处理器 +type S3Handler struct { + bucketManager *bucket.Manager + balancer *balancer.Balancer + presigner *presigner.Presigner + storage *storage.Service + accessKey string + secretKey string +} + +// NewS3Handler 创建新的S3兼容API处理器 +func NewS3Handler( + bucketManager *bucket.Manager, + balancer *balancer.Balancer, + presigner *presigner.Presigner, + storage *storage.Service, + accessKey string, + secretKey string, +) *S3Handler { + return &S3Handler{ + bucketManager: bucketManager, + balancer: balancer, + presigner: presigner, + storage: storage, + accessKey: accessKey, + secretKey: secretKey, + } +} + +// RegisterS3Routes 注册S3兼容的路由 +func (h *S3Handler) RegisterS3Routes(router *mux.Router) { + // Service operations + router.HandleFunc("/", h.handleListBuckets).Methods("GET") + + // Bucket operations + router.HandleFunc("/{bucket}", h.handleBucketOperations).Methods("GET", "HEAD", "PUT", "DELETE") + + // Object operations + router.HandleFunc("/{bucket}/{key:.*}", h.handleObjectOperations).Methods("GET", "HEAD", "PUT", "DELETE") + + // Multipart upload operations + router.HandleFunc("/{bucket}/{key:.*}", h.handleMultipartUpload).Methods("POST").Queries("uploads", "") + router.HandleFunc("/{bucket}/{key:.*}", h.handleListMultipartUploads).Methods("GET").Queries("uploads", "") + router.HandleFunc("/{bucket}/{key:.*}", h.handleListMultipartParts).Methods("GET").Queries("uploadId", "") + router.HandleFunc("/{bucket}/{key:.*}", h.handleCompleteMultipartUpload).Methods("POST").Queries("uploadId", "") + router.HandleFunc("/{bucket}/{key:.*}", h.handleAbortMultipartUpload).Methods("DELETE").Queries("uploadId", "") + + // 添加认证中间件 + router.Use(h.s3AuthMiddleware) +} + +// S3 XML响应结构体定义 +type ListBucketsResult struct { + XMLName xml.Name `xml:"ListAllMyBucketsResult"` + Xmlns string `xml:"xmlns,attr"` + Owner Owner `xml:"Owner"` + Buckets Buckets `xml:"Buckets"` +} + +type Owner struct { + ID string `xml:"ID"` + DisplayName string `xml:"DisplayName"` +} + +type Buckets struct { + Bucket []BucketInfo `xml:"Bucket"` +} + +type BucketInfo struct { + Name string `xml:"Name"` + CreationDate time.Time `xml:"CreationDate"` +} + +type ListBucketResult struct { + XMLName xml.Name `xml:"ListBucketResult"` + Xmlns string `xml:"xmlns,attr"` + Name string `xml:"Name"` + Prefix string `xml:"Prefix"` + Marker string `xml:"Marker"` + MaxKeys int `xml:"MaxKeys"` + IsTruncated bool `xml:"IsTruncated"` + Contents []ObjectInfo `xml:"Contents"` + CommonPrefixes []CommonPrefix `xml:"CommonPrefixes,omitempty"` +} + +type ObjectInfo struct { + Key string `xml:"Key"` + LastModified time.Time `xml:"LastModified"` + ETag string `xml:"ETag"` + Size int64 `xml:"Size"` + StorageClass string `xml:"StorageClass"` + Owner Owner `xml:"Owner"` +} + +type CommonPrefix struct { + Prefix string `xml:"Prefix"` +} + +type InitiateMultipartUploadResult struct { + XMLName xml.Name `xml:"InitiateMultipartUploadResult"` + Xmlns string `xml:"xmlns,attr"` + Bucket string `xml:"Bucket"` + Key string `xml:"Key"` + UploadID string `xml:"UploadId"` +} + +type ListMultipartUploadsResult struct { + XMLName xml.Name `xml:"ListMultipartUploadsResult"` + Xmlns string `xml:"xmlns,attr"` + Bucket string `xml:"Bucket"` + KeyMarker string `xml:"KeyMarker"` + UploadIdMarker string `xml:"UploadIdMarker"` + NextKeyMarker string `xml:"NextKeyMarker"` + NextUploadIdMarker string `xml:"NextUploadIdMarker"` + MaxUploads int `xml:"MaxUploads"` + IsTruncated bool `xml:"IsTruncated"` + Uploads []Upload `xml:"Upload"` + CommonPrefixes []CommonPrefix `xml:"CommonPrefixes,omitempty"` +} + +type Upload struct { + Key string `xml:"Key"` + UploadID string `xml:"UploadId"` + Initiator Owner `xml:"Initiator"` + Owner Owner `xml:"Owner"` + StorageClass string `xml:"StorageClass"` + Initiated time.Time `xml:"Initiated"` +} + +type ListPartsResult struct { + XMLName xml.Name `xml:"ListPartsResult"` + Xmlns string `xml:"xmlns,attr"` + Bucket string `xml:"Bucket"` + Key string `xml:"Key"` + UploadID string `xml:"UploadId"` + PartNumberMarker int `xml:"PartNumberMarker"` + NextPartNumberMarker int `xml:"NextPartNumberMarker"` + MaxParts int `xml:"MaxParts"` + IsTruncated bool `xml:"IsTruncated"` + Parts []Part `xml:"Part"` +} + +type Part struct { + PartNumber int `xml:"PartNumber"` + LastModified time.Time `xml:"LastModified"` + ETag string `xml:"ETag"` + Size int64 `xml:"Size"` +} + +type CompleteMultipartUpload struct { + XMLName xml.Name `xml:"CompleteMultipartUpload"` + Parts []Part `xml:"Part"` +} + +type CompleteMultipartUploadResult struct { + XMLName xml.Name `xml:"CompleteMultipartUploadResult"` + Xmlns string `xml:"xmlns,attr"` + Location string `xml:"Location"` + Bucket string `xml:"Bucket"` + Key string `xml:"Key"` + ETag string `xml:"ETag"` +} + +type ErrorResponse struct { + XMLName xml.Name `xml:"Error"` + Code string `xml:"Code"` + Message string `xml:"Message"` + Resource string `xml:"Resource"` + RequestID string `xml:"RequestId"` +} + +// s3AuthMiddleware S3认证中间件(简化版) +func (h *S3Handler) s3AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 简化的认证实现,实际应该验证AWS Signature + // 这里只做基本的header检查 + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + // 允许匿名访问(用于测试) + // 在生产环境中应该要求认证 + } + + next.ServeHTTP(w, r) + }) +} + +// handleListBuckets 处理列出所有存储桶请求 +func (h *S3Handler) handleListBuckets(w http.ResponseWriter, r *http.Request) { + buckets := h.bucketManager.GetAllBuckets() + + result := ListBucketsResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Owner: Owner{ + ID: "s3-balance", + DisplayName: "S3 Balance Service", + }, + Buckets: Buckets{ + Bucket: make([]BucketInfo, 0, len(buckets)), + }, + } + + for _, b := range buckets { + if b.IsAvailable() { + result.Buckets.Bucket = append(result.Buckets.Bucket, BucketInfo{ + Name: b.Config.Name, + CreationDate: time.Now().Add(-24 * time.Hour), // 模拟创建时间 + }) + } + } + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleBucketOperations 处理存储桶相关操作 +func (h *S3Handler) handleBucketOperations(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + + switch r.Method { + case "GET": + h.handleListObjects(w, r, bucketName) + case "HEAD": + h.handleHeadBucket(w, r, bucketName) + case "PUT": + h.handleCreateBucket(w, r, bucketName) + case "DELETE": + h.handleDeleteBucket(w, r, bucketName) + } +} + +// handleListObjects 列出存储桶中的对象 +func (h *S3Handler) handleListObjects(w http.ResponseWriter, r *http.Request, bucketName string) { + // 检查bucket是否存在 + if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 解析查询参数 + prefix := r.URL.Query().Get("prefix") + marker := r.URL.Query().Get("marker") + maxKeysStr := r.URL.Query().Get("max-keys") + delimiter := r.URL.Query().Get("delimiter") + + maxKeys := 1000 + if maxKeysStr != "" { + if mk, err := strconv.Atoi(maxKeysStr); err == nil { + maxKeys = mk + } + } + + // 从存储中获取对象列表 + objects, err := h.storage.ListObjects(bucketName, prefix, marker, maxKeys) + if err != nil { + h.sendS3Error(w, "InternalError", "Internal server error", bucketName) + return + } + + result := ListBucketResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Name: bucketName, + Prefix: prefix, + Marker: marker, + MaxKeys: maxKeys, + IsTruncated: false, // 简化实现 + Contents: make([]ObjectInfo, 0), + } + + // 处理分隔符逻辑 + if delimiter != "" { + // 简化的分隔符处理 + result.CommonPrefixes = make([]CommonPrefix, 0) + } + + for _, obj := range objects { + result.Contents = append(result.Contents, ObjectInfo{ + Key: obj.Key, + LastModified: obj.UpdatedAt, + ETag: fmt.Sprintf("\"%x\"", obj.ID), // 简化的ETag + Size: obj.Size, + StorageClass: "STANDARD", + Owner: Owner{ + ID: "s3-balance", + DisplayName: "S3 Balance Service", + }, + }) + } + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleHeadBucket 检查存储桶是否存在 +func (h *S3Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request, bucketName string) { + if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusOK) +} + +// handleCreateBucket 创建存储桶(虚拟实现) +func (h *S3Handler) handleCreateBucket(w http.ResponseWriter, r *http.Request, bucketName string) { + // 在负载均衡场景下,不真正创建bucket,只返回成功 + // 实际的bucket应该在配置中预先定义 + w.Header().Set("Location", "/" + bucketName) + w.WriteHeader(http.StatusOK) +} + +// handleDeleteBucket 删除存储桶(虚拟实现) +func (h *S3Handler) handleDeleteBucket(w http.ResponseWriter, r *http.Request, bucketName string) { + // 在负载均衡场景下,不真正删除bucket + w.WriteHeader(http.StatusNoContent) +} + +// handleObjectOperations 处理对象相关操作 +func (h *S3Handler) handleObjectOperations(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + key := vars["key"] + + switch r.Method { + case "GET": + h.handleGetObject(w, r, bucketName, key) + case "HEAD": + h.handleHeadObject(w, r, bucketName, key) + case "PUT": + h.handlePutObject(w, r, bucketName, key) + case "DELETE": + h.handleDeleteObject(w, r, bucketName, key) + } +} + +// handleGetObject 获取对象(默认使用预签名URL重定向) +func (h *S3Handler) handleGetObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) { + // 查找对象所在的实际存储桶 + actualBucketName, err := h.storage.FindObjectBucket(key) + if err != nil { + h.sendS3Error(w, "NoSuchKey", "The specified key does not exist", key) + return + } + + bucket, ok := h.bucketManager.GetBucket(actualBucketName) + if !ok { + h.sendS3Error(w, "InternalError", "Internal server error", bucketName) + return + } + + // 生成预签名下载URL + downloadInfo, err := h.presigner.GenerateDownloadURL( + context.Background(), + bucket, + key, + ) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to generate download URL", key) + return + } + + // 默认使用预签名重定向模式,只有明确指定时才使用代理模式 + if r.URL.Query().Get("proxy") == "true" { + // 代理模式:服务器下载内容并返回给客户端 + resp, err := http.Get(downloadInfo.URL) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to fetch object", key) + return + } + defer resp.Body.Close() + + // 复制响应头 + for k, v := range resp.Header { + w.Header()[k] = v + } + + // 复制响应体 + io.Copy(w, resp.Body) + } else { + // 重定向模式:返回302重定向到预签名URL(默认) + http.Redirect(w, r, downloadInfo.URL, http.StatusFound) + } +} + +// handleHeadObject 获取对象元数据 +func (h *S3Handler) handleHeadObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) { + // 从存储中获取对象信息 + obj, err := h.storage.GetObjectInfo(key) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + w.Header().Set("Content-Length", strconv.FormatInt(obj.Size, 10)) + w.Header().Set("Last-Modified", obj.UpdatedAt.Format(http.TimeFormat)) + w.Header().Set("ETag", fmt.Sprintf("\"%x\"", obj.ID)) + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) +} + +// handlePutObject 上传对象(默认使用预签名URL重定向) +func (h *S3Handler) handlePutObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) { + // 获取内容长度 + contentLength := r.ContentLength + if contentLength < 0 { + h.sendS3Error(w, "MissingContentLength", "Content-Length header is required", key) + return + } + + // 选择目标存储桶 + targetBucket, err := h.balancer.SelectBucket(key, contentLength) + if err != nil { + h.sendS3Error(w, "InsufficientStorage", "No bucket has enough space", key) + return + } + + // 生成预签名上传URL + uploadInfo, err := h.presigner.GenerateUploadURL( + context.Background(), + targetBucket, + key, + r.Header.Get("Content-Type"), + nil, // metadata + ) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to generate upload URL", key) + return + } + + // 默认使用预签名重定向模式,只有明确指定时才使用代理模式 + if r.URL.Query().Get("proxy") == "true" { + // 代理模式:读取请求体并上传到预签名URL + // 创建新的请求 + req, err := http.NewRequest(uploadInfo.Method, uploadInfo.URL, r.Body) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to create upload request", key) + return + } + + // 设置必要的头 + req.ContentLength = contentLength + if ct := r.Header.Get("Content-Type"); ct != "" { + req.Header.Set("Content-Type", ct) + } + + // 添加预签名URL所需的额外头 + for k, v := range uploadInfo.Headers { + req.Header.Set(k, v) + } + + // 执行上传 + client := &http.Client{Timeout: 30 * time.Minute} + resp, err := client.Do(req) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to upload object", key) + return + } + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + // 记录对象元数据 + h.storage.RecordObject(key, targetBucket.Config.Name, contentLength, nil) + targetBucket.UpdateUsedSize(contentLength) + + // 返回成功响应 + w.Header().Set("ETag", fmt.Sprintf("\"%x\"", time.Now().UnixNano())) + w.WriteHeader(http.StatusOK) + } else { + h.sendS3Error(w, "InternalError", "Upload failed", key) + } + } else { + // 重定向模式:返回307临时重定向让客户端直接上传(默认) + w.Header().Set("Location", uploadInfo.URL) + w.WriteHeader(http.StatusTemporaryRedirect) + } +} + +// handleDeleteObject 删除对象 +func (h *S3Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) { + // 查找对象所在的实际存储桶 + actualBucketName, err := h.storage.FindObjectBucket(key) + if err != nil { + // 对象不存在,S3规范要求返回204 + w.WriteHeader(http.StatusNoContent) + return + } + + bucket, ok := h.bucketManager.GetBucket(actualBucketName) + if !ok { + h.sendS3Error(w, "InternalError", "Internal server error", bucketName) + return + } + + // 生成预签名删除URL + deleteInfo, err := h.presigner.GenerateDeleteURL( + context.Background(), + bucket, + key, + ) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to generate delete URL", key) + return + } + + // 执行删除 + req, _ := http.NewRequest("DELETE", deleteInfo.URL, nil) + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to delete object", key) + return + } + defer resp.Body.Close() + + // 从数据库中删除对象记录 + h.storage.DeleteObject(key) + + // S3规范要求删除操作总是返回204 + w.WriteHeader(http.StatusNoContent) +} + +// handleMultipartUpload 初始化分片上传 +func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + key := vars["key"] + + // 选择目标存储桶 + targetBucket, err := h.balancer.SelectBucket(key, 0) // 分片上传时不检查空间 + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to select bucket for upload", key) + return + } + + // 初始化分片上传 + ctx := context.Background() + createResp, err := targetBucket.Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ + Bucket: aws.String(targetBucket.Config.Name), + Key: aws.String(key), + }) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to initiate multipart upload", key) + return + } + + result := InitiateMultipartUploadResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Bucket: targetBucket.Config.Name, + Key: key, + UploadID: *createResp.UploadId, + } + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleListMultipartUploads 列出分片上传 +func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + key := vars["key"] + + // 检查bucket是否存在 + if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 简化实现:返回空列表 + result := ListMultipartUploadsResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Bucket: bucketName, + KeyMarker: key, + MaxUploads: 1000, + IsTruncated: false, + Uploads: make([]Upload, 0), + } + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleListMultipartParts 列出分片上传的分片 +func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + key := vars["key"] + uploadID := r.URL.Query().Get("uploadId") + + // 检查bucket是否存在 + if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 简化实现:返回空列表 + result := ListPartsResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Bucket: bucketName, + Key: key, + UploadID: uploadID, + PartNumberMarker: 0, + MaxParts: 1000, + IsTruncated: false, + Parts: make([]Part, 0), + } + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleCompleteMultipartUpload 完成分片上传 +func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + key := vars["key"] + uploadID := r.URL.Query().Get("uploadId") + + // 查找对象所在的实际存储桶(简化实现,使用配置的bucket) + bucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 解析请求体以获取分片列表 + var completeReq CompleteMultipartUpload + body, _ := io.ReadAll(r.Body) + xml.Unmarshal(body, &completeReq) + + // 完成分片上传 + ctx := context.Background() + var parts []types.CompletedPart + for _, part := range completeReq.Parts { + parts = append(parts, types.CompletedPart{ + ETag: aws.String(part.ETag), + PartNumber: aws.Int32(int32(part.PartNumber)), + }) + } + + completeResp, err := bucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + MultipartUpload: &types.CompletedMultipartUpload{ + Parts: parts, + }, + }) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to complete multipart upload", key) + return + } + + result := CompleteMultipartUploadResult{ + Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", + Location: "/" + bucket.Config.Name + "/" + key, + Bucket: bucket.Config.Name, + Key: key, + ETag: *completeResp.ETag, + } + + // 记录对象元数据(简化:假设总大小) + h.storage.RecordObject(key, bucket.Config.Name, 0, nil) + + h.sendXMLResponse(w, http.StatusOK, result) +} + +// handleAbortMultipartUpload 中止分片上传 +func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + bucketName := vars["bucket"] + key := vars["key"] + uploadID := r.URL.Query().Get("uploadId") + + // 查找对象所在的实际存储桶(简化实现,使用配置的bucket) + bucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 中止分片上传 + ctx := context.Background() + _, err := bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + }) + if err != nil { + h.sendS3Error(w, "InternalError", "Failed to abort multipart upload", key) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// sendXMLResponse 发送XML响应 +func (h *S3Handler) sendXMLResponse(w http.ResponseWriter, statusCode int, data interface{}) { + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(statusCode) + + encoder := xml.NewEncoder(w) + encoder.Indent("", " ") + + // 写入XML声明 + w.Write([]byte(xml.Header)) + + if err := encoder.Encode(data); err != nil { + // 如果编码失败,记录错误 + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} + +// sendS3Error 发送S3错误响应 +func (h *S3Handler) sendS3Error(w http.ResponseWriter, code string, message string, resource string) { + errorResp := ErrorResponse{ + Code: code, + Message: message, + Resource: resource, + RequestID: fmt.Sprintf("%d", time.Now().UnixNano()), + } + + statusCode := http.StatusBadRequest + switch code { + case "NoSuchBucket", "NoSuchKey": + statusCode = http.StatusNotFound + case "BucketAlreadyExists": + statusCode = http.StatusConflict + case "InvalidAccessKeyId", "SignatureDoesNotMatch": + statusCode = http.StatusForbidden + case "InternalError": + statusCode = http.StatusInternalServerError + case "InsufficientStorage": + statusCode = http.StatusInsufficientStorage + } + + h.sendXMLResponse(w, statusCode, errorResp) +} + +// 辅助函数:解析S3路径 +func parseS3Path(requestPath string) (bucket string, key string) { + requestPath = strings.TrimPrefix(requestPath, "/") + parts := strings.SplitN(requestPath, "/", 2) + + if len(parts) > 0 { + bucket = parts[0] + } + if len(parts) > 1 { + key = parts[1] + } + + return bucket, key +} + +// 辅助函数:URL编码/解码 +func urlEncodePath(p string) string { + return strings.ReplaceAll(url.QueryEscape(p), "+", "%20") +} + +func urlDecodePath(p string) string { + decoded, _ := url.QueryUnescape(p) + return decoded +} diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go new file mode 100644 index 0000000..8b3d508 --- /dev/null +++ b/internal/balancer/balancer.go @@ -0,0 +1,264 @@ +package balancer + +import ( + "crypto/md5" + "encoding/binary" + "fmt" + "math/rand" + "sort" + "sync" + "sync/atomic" + + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/DullJZ/s3-balance/internal/config" +) + +// Strategy 负载均衡策略接口 +type Strategy interface { + SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) + Name() string +} + +// Balancer 负载均衡器 +type Balancer struct { + manager *bucket.Manager + strategy Strategy + config *config.BalancerConfig +} + +// NewBalancer 创建新的负载均衡器 +func NewBalancer(manager *bucket.Manager, cfg *config.BalancerConfig) (*Balancer, error) { + var strategy Strategy + + switch cfg.Strategy { + case "round-robin": + strategy = NewRoundRobinStrategy() + case "least-space": + strategy = NewLeastSpaceStrategy() + case "weighted": + strategy = NewWeightedStrategy() + case "consistent-hash": + strategy = NewConsistentHashStrategy() + default: + return nil, fmt.Errorf("unknown balancer strategy: %s", cfg.Strategy) + } + + return &Balancer{ + manager: manager, + strategy: strategy, + config: cfg, + }, nil +} + +// SelectBucket 选择一个存储桶 +func (b *Balancer) SelectBucket(key string, size int64) (*bucket.BucketInfo, error) { + // 获取所有可用的存储桶 + buckets := b.manager.GetAvailableBuckets() + if len(buckets) == 0 { + return nil, fmt.Errorf("no available buckets") + } + + // 过滤出有足够空间的存储桶 + var availableBuckets []*bucket.BucketInfo + for _, bucket := range buckets { + if bucket.GetAvailableSpace() >= size { + availableBuckets = append(availableBuckets, bucket) + } + } + + if len(availableBuckets) == 0 { + return nil, fmt.Errorf("no bucket has enough space for %d bytes", size) + } + + // 使用策略选择存储桶 + selected, err := b.strategy.SelectBucket(availableBuckets, key, size) + if err != nil { + return nil, err + } + + return selected, nil +} + +// GetStrategy 获取当前策略名称 +func (b *Balancer) GetStrategy() string { + return b.strategy.Name() +} + +// RoundRobinStrategy 轮询策略 +type RoundRobinStrategy struct { + counter uint64 +} + +// NewRoundRobinStrategy 创建轮询策略 +func NewRoundRobinStrategy() *RoundRobinStrategy { + return &RoundRobinStrategy{} +} + +// SelectBucket 选择存储桶(轮询) +func (s *RoundRobinStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) { + if len(buckets) == 0 { + return nil, fmt.Errorf("no buckets available") + } + + index := atomic.AddUint64(&s.counter, 1) % uint64(len(buckets)) + return buckets[index], nil +} + +// Name 返回策略名称 +func (s *RoundRobinStrategy) Name() string { + return "round-robin" +} + +// LeastSpaceStrategy 最少使用空间策略 +type LeastSpaceStrategy struct{} + +// NewLeastSpaceStrategy 创建最少使用空间策略 +func NewLeastSpaceStrategy() *LeastSpaceStrategy { + return &LeastSpaceStrategy{} +} + +// SelectBucket 选择存储桶(选择使用空间最少的) +func (s *LeastSpaceStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) { + if len(buckets) == 0 { + return nil, fmt.Errorf("no buckets available") + } + + // 按可用空间排序(从大到小) + sort.Slice(buckets, func(i, j int) bool { + return buckets[i].GetAvailableSpace() > buckets[j].GetAvailableSpace() + }) + + return buckets[0], nil +} + +// Name 返回策略名称 +func (s *LeastSpaceStrategy) Name() string { + return "least-space" +} + +// WeightedStrategy 加权策略 +type WeightedStrategy struct { + mu sync.RWMutex +} + +// NewWeightedStrategy 创建加权策略 +func NewWeightedStrategy() *WeightedStrategy { + return &WeightedStrategy{} +} + +// SelectBucket 选择存储桶(基于权重) +func (s *WeightedStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) { + if len(buckets) == 0 { + return nil, fmt.Errorf("no buckets available") + } + + // 计算总权重 + totalWeight := 0 + for _, b := range buckets { + totalWeight += b.Config.Weight + } + + if totalWeight == 0 { + // 如果所有权重都是0,则随机选择 + return buckets[rand.Intn(len(buckets))], nil + } + + // 根据权重随机选择 + randomWeight := rand.Intn(totalWeight) + currentWeight := 0 + + for _, b := range buckets { + currentWeight += b.Config.Weight + if randomWeight < currentWeight { + return b, nil + } + } + + // 不应该到达这里,但为了安全返回最后一个 + return buckets[len(buckets)-1], nil +} + +// Name 返回策略名称 +func (s *WeightedStrategy) Name() string { + return "weighted" +} + +// ConsistentHashStrategy 一致性哈希策略 +type ConsistentHashStrategy struct { + replicas int + ring map[uint32]*bucket.BucketInfo + nodes []uint32 + mu sync.RWMutex +} + +// NewConsistentHashStrategy 创建一致性哈希策略 +func NewConsistentHashStrategy() *ConsistentHashStrategy { + return &ConsistentHashStrategy{ + replicas: 100, // 每个节点的虚拟节点数 + ring: make(map[uint32]*bucket.BucketInfo), + } +} + +// SelectBucket 选择存储桶(基于一致性哈希) +func (s *ConsistentHashStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) { + if len(buckets) == 0 { + return nil, fmt.Errorf("no buckets available") + } + + // 更新哈希环 + s.updateRing(buckets) + + // 计算key的哈希值 + hash := s.hash(key) + + // 在环上找到第一个大于等于hash的节点 + s.mu.RLock() + defer s.mu.RUnlock() + + idx := sort.Search(len(s.nodes), func(i int) bool { + return s.nodes[i] >= hash + }) + + // 如果没找到,返回第一个节点(环形结构) + if idx == len(s.nodes) { + idx = 0 + } + + return s.ring[s.nodes[idx]], nil +} + +// updateRing 更新哈希环 +func (s *ConsistentHashStrategy) updateRing(buckets []*bucket.BucketInfo) { + s.mu.Lock() + defer s.mu.Unlock() + + // 清空现有环 + s.ring = make(map[uint32]*bucket.BucketInfo) + s.nodes = nil + + // 为每个存储桶添加虚拟节点 + for _, b := range buckets { + for i := 0; i < s.replicas; i++ { + virtualKey := fmt.Sprintf("%s-%d", b.Config.Name, i) + hash := s.hash(virtualKey) + s.ring[hash] = b + s.nodes = append(s.nodes, hash) + } + } + + // 排序节点 + sort.Slice(s.nodes, func(i, j int) bool { + return s.nodes[i] < s.nodes[j] + }) +} + +// hash 计算哈希值 +func (s *ConsistentHashStrategy) hash(key string) uint32 { + h := md5.Sum([]byte(key)) + return binary.BigEndian.Uint32(h[:4]) +} + +// Name 返回策略名称 +func (s *ConsistentHashStrategy) Name() string { + return "consistent-hash" +} diff --git a/internal/bucket/manager.go b/internal/bucket/manager.go new file mode 100644 index 0000000..c40cff0 --- /dev/null +++ b/internal/bucket/manager.go @@ -0,0 +1,310 @@ +package bucket + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/DullJZ/s3-balance/internal/config" + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// BucketInfo 存储桶信息 +type BucketInfo struct { + Config config.BucketConfig + Client *s3.Client + UsedSize int64 // 已使用容量(字节) + Available bool // 是否可用 + LastChecked time.Time // 最后检查时间 + mu sync.RWMutex +} + +// Manager 存储桶管理器 +type Manager struct { + buckets map[string]*BucketInfo + mu sync.RWMutex + config *config.Config + stopChan chan struct{} +} + +// NewManager 创建新的存储桶管理器 +func NewManager(cfg *config.Config) (*Manager, error) { + m := &Manager{ + buckets: make(map[string]*BucketInfo), + config: cfg, + stopChan: make(chan struct{}), + } + + // 初始化所有存储桶客户端 + for _, bucketCfg := range cfg.Buckets { + if !bucketCfg.Enabled { + continue + } + + client, err := createS3Client(bucketCfg) + if err != nil { + return nil, fmt.Errorf("failed to create S3 client for bucket %s: %w", bucketCfg.Name, err) + } + + info := &BucketInfo{ + Config: bucketCfg, + Client: client, + Available: true, + LastChecked: time.Now(), + } + + m.buckets[bucketCfg.Name] = info + } + + return m, nil +} + +// createS3Client 创建S3客户端 +func createS3Client(bucketCfg config.BucketConfig) (*s3.Client, error) { + // 创建自定义端点解析器 + customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if bucketCfg.Endpoint != "" { + return aws.Endpoint{ + URL: bucketCfg.Endpoint, + SigningRegion: bucketCfg.Region, + HostnameImmutable: true, + }, nil + } + // 返回错误以使用默认解析器 + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) + + // 配置AWS SDK + cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), + awsconfig.WithRegion(bucketCfg.Region), + awsconfig.WithEndpointResolverWithOptions(customResolver), + awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider( + bucketCfg.AccessKeyID, + bucketCfg.SecretAccessKey, + "", + ), + ), + ) + if err != nil { + return nil, err + } + + // 创建S3客户端 + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.UsePathStyle = bucketCfg.PathStyle + }) + + return client, nil +} + +// Start 启动管理器(健康检查和统计更新) +func (m *Manager) Start(ctx context.Context) { + // 启动健康检查 + go m.healthCheckLoop(ctx) + + // 启动统计更新 + go m.statsUpdateLoop(ctx) +} + +// Stop 停止管理器 +func (m *Manager) Stop() { + close(m.stopChan) +} + +// healthCheckLoop 健康检查循环 +func (m *Manager) healthCheckLoop(ctx context.Context) { + ticker := time.NewTicker(m.config.Balancer.HealthCheckPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-m.stopChan: + return + case <-ticker.C: + m.checkAllBuckets(ctx) + } + } +} + +// statsUpdateLoop 统计更新循环 +func (m *Manager) statsUpdateLoop(ctx context.Context) { + ticker := time.NewTicker(m.config.Balancer.UpdateStatsPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-m.stopChan: + return + case <-ticker.C: + m.updateAllStats(ctx) + } + } +} + +// checkAllBuckets 检查所有存储桶的健康状态 +func (m *Manager) checkAllBuckets(ctx context.Context) { + m.mu.RLock() + buckets := make([]*BucketInfo, 0, len(m.buckets)) + for _, b := range m.buckets { + buckets = append(buckets, b) + } + m.mu.RUnlock() + + var wg sync.WaitGroup + for _, bucket := range buckets { + wg.Add(1) + go func(b *BucketInfo) { + defer wg.Done() + m.checkBucket(ctx, b) + }(bucket) + } + wg.Wait() +} + +// checkBucket 检查单个存储桶 +func (m *Manager) checkBucket(ctx context.Context, bucket *BucketInfo) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + // 尝试列出存储桶(用于健康检查) + _, err := bucket.Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket.Config.Name), + MaxKeys: aws.Int32(1), + }) + + bucket.mu.Lock() + bucket.Available = err == nil + bucket.LastChecked = time.Now() + bucket.mu.Unlock() +} + +// updateAllStats 更新所有存储桶的统计信息 +func (m *Manager) updateAllStats(ctx context.Context) { + m.mu.RLock() + buckets := make([]*BucketInfo, 0, len(m.buckets)) + for _, b := range m.buckets { + buckets = append(buckets, b) + } + m.mu.RUnlock() + + var wg sync.WaitGroup + for _, bucket := range buckets { + wg.Add(1) + go func(b *BucketInfo) { + defer wg.Done() + m.updateBucketStats(ctx, b) + }(bucket) + } + wg.Wait() +} + +// updateBucketStats 更新单个存储桶的统计信息 +func (m *Manager) updateBucketStats(ctx context.Context, bucket *BucketInfo) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + var totalSize int64 + var continuationToken *string + + for { + output, err := bucket.Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket.Config.Name), + ContinuationToken: continuationToken, + }) + if err != nil { + break + } + + for _, obj := range output.Contents { + if obj.Size != nil { + totalSize += *obj.Size + } + } + + if output.IsTruncated == nil || !*output.IsTruncated { + break + } + continuationToken = output.NextContinuationToken + } + + bucket.mu.Lock() + bucket.UsedSize = totalSize + bucket.mu.Unlock() +} + +// GetBucket 获取指定名称的存储桶 +func (m *Manager) GetBucket(name string) (*BucketInfo, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + bucket, ok := m.buckets[name] + return bucket, ok +} + +// GetAllBuckets 获取所有存储桶 +func (m *Manager) GetAllBuckets() []*BucketInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + buckets := make([]*BucketInfo, 0, len(m.buckets)) + for _, b := range m.buckets { + buckets = append(buckets, b) + } + return buckets +} + +// GetAvailableBuckets 获取所有可用的存储桶 +func (m *Manager) GetAvailableBuckets() []*BucketInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + var available []*BucketInfo + for _, b := range m.buckets { + b.mu.RLock() + if b.Available && (b.Config.MaxSizeBytes == 0 || b.UsedSize < b.Config.MaxSizeBytes) { + available = append(available, b) + } + b.mu.RUnlock() + } + return available +} + +// GetAvailableSpace 获取存储桶的可用空间 +func (b *BucketInfo) GetAvailableSpace() int64 { + b.mu.RLock() + defer b.mu.RUnlock() + + if b.Config.MaxSizeBytes == 0 { + return 1 << 62 // 返回一个很大的数表示无限制 + } + return b.Config.MaxSizeBytes - b.UsedSize +} + +// IsAvailable 检查存储桶是否可用 +func (b *BucketInfo) IsAvailable() bool { + b.mu.RLock() + defer b.mu.RUnlock() + return b.Available +} + +// GetUsedSize 获取已使用容量 +func (b *BucketInfo) GetUsedSize() int64 { + b.mu.RLock() + defer b.mu.RUnlock() + return b.UsedSize +} + +// UpdateUsedSize 更新已使用容量 +func (b *BucketInfo) UpdateUsedSize(delta int64) { + b.mu.Lock() + defer b.mu.Unlock() + b.UsedSize += delta +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..2fc76a3 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,212 @@ +package config + +import ( + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" +) + +// Config 全局配置结构 +type Config struct { + Server ServerConfig `yaml:"server"` + Database DatabaseConfig `yaml:"database"` + Buckets []BucketConfig `yaml:"buckets"` + Balancer BalancerConfig `yaml:"balancer"` + Metrics MetricsConfig `yaml:"metrics"` + S3API S3APIConfig `yaml:"s3api"` +} + +// ServerConfig 服务器配置 +type ServerConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + ReadTimeout time.Duration `yaml:"read_timeout"` + WriteTimeout time.Duration `yaml:"write_timeout"` + IdleTimeout time.Duration `yaml:"idle_timeout"` +} + +// BucketConfig S3存储桶配置 +type BucketConfig struct { + Name string `yaml:"name"` // 桶名称 + Endpoint string `yaml:"endpoint"` // S3端点 + Region string `yaml:"region"` // 区域 + AccessKeyID string `yaml:"access_key_id"` // 访问密钥ID + SecretAccessKey string `yaml:"secret_access_key"` // 访问密钥 + MaxSize string `yaml:"max_size"` // 最大容量 (例如: "10GB") + MaxSizeBytes int64 `yaml:"-"` // 内部使用,字节为单位 + Weight int `yaml:"weight"` // 权重 (用于负载均衡) + Enabled bool `yaml:"enabled"` // 是否启用 + UseSSL bool `yaml:"use_ssl"` // 是否使用SSL + PathStyle bool `yaml:"path_style"` // 是否使用路径风格访问 +} + +// BalancerConfig 负载均衡配置 +type BalancerConfig struct { + Strategy string `yaml:"strategy"` // 负载均衡策略: "round-robin", "least-space", "weighted", "consistent-hash" + HealthCheckPeriod time.Duration `yaml:"health_check_period"` // 健康检查周期 + UpdateStatsPeriod time.Duration `yaml:"update_stats_period"` // 统计更新周期 + RetryAttempts int `yaml:"retry_attempts"` // 重试次数 + RetryDelay time.Duration `yaml:"retry_delay"` // 重试延迟 +} + +// MetricsConfig 监控指标配置 +type MetricsConfig struct { + Enabled bool `yaml:"enabled"` + Path string `yaml:"path"` + Port int `yaml:"port"` +} + +// S3APIConfig S3兼容API配置 +type S3APIConfig struct { + AccessKey string `yaml:"access_key"` // S3访问密钥ID + SecretKey string `yaml:"secret_key"` // S3秘密访问密钥 + VirtualHost bool `yaml:"virtual_host"` // 是否使用虚拟主机模式 + ProxyMode bool `yaml:"proxy_mode"` // 是否使用代理模式(而非重定向) + AuthRequired bool `yaml:"auth_required"` // 是否需要认证 +} + +// DatabaseConfig 数据库配置 +type DatabaseConfig struct { + Type string `yaml:"type"` // 数据库类型: sqlite, mysql, postgres + DSN string `yaml:"dsn"` // 数据源名称 + MaxOpenConns int `yaml:"max_open_conns"` // 最大打开连接数 + MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数 + ConnMaxLifetime int `yaml:"conn_max_lifetime"` // 连接最大生命周期(秒) + LogLevel string `yaml:"log_level"` // 日志级别: silent, error, warn, info + AutoMigrate bool `yaml:"auto_migrate"` // 是否自动迁移 +} + +// Load 从文件加载配置 +func Load(configPath string) (*Config, error) { + file, err := os.Open(configPath) + if err != nil { + return nil, fmt.Errorf("failed to open config file: %w", err) + } + defer file.Close() + + var config Config + decoder := yaml.NewDecoder(file) + if err := decoder.Decode(&config); err != nil { + return nil, fmt.Errorf("failed to decode config: %w", err) + } + + // 解析容量大小 + for i := range config.Buckets { + if err := config.Buckets[i].ParseMaxSize(); err != nil { + return nil, fmt.Errorf("failed to parse max size for bucket %s: %w", + config.Buckets[i].Name, err) + } + } + + // 设置默认值 + config.SetDefaults() + + return &config, nil +} + +// SetDefaults 设置默认配置值 +func (c *Config) SetDefaults() { + if c.Server.Host == "" { + c.Server.Host = "0.0.0.0" + } + if c.Server.Port == 0 { + c.Server.Port = 8080 + } + if c.Server.ReadTimeout == 0 { + c.Server.ReadTimeout = 30 * time.Second + } + if c.Server.WriteTimeout == 0 { + c.Server.WriteTimeout = 30 * time.Second + } + if c.Server.IdleTimeout == 0 { + c.Server.IdleTimeout = 60 * time.Second + } + + if c.Balancer.Strategy == "" { + c.Balancer.Strategy = "least-space" + } + if c.Balancer.HealthCheckPeriod == 0 { + c.Balancer.HealthCheckPeriod = 30 * time.Second + } + if c.Balancer.UpdateStatsPeriod == 0 { + c.Balancer.UpdateStatsPeriod = 60 * time.Second + } + if c.Balancer.RetryAttempts == 0 { + c.Balancer.RetryAttempts = 3 + } + if c.Balancer.RetryDelay == 0 { + c.Balancer.RetryDelay = time.Second + } + + if c.Metrics.Path == "" { + c.Metrics.Path = "/metrics" + } + if c.Metrics.Port == 0 { + c.Metrics.Port = 9090 + } + + // 数据库默认值 + if c.Database.Type == "" { + c.Database.Type = "sqlite" + } + if c.Database.Type == "sqlite" && c.Database.DSN == "" { + c.Database.DSN = "data/s3-balance.db" + } + if c.Database.MaxOpenConns == 0 { + c.Database.MaxOpenConns = 25 + } + if c.Database.MaxIdleConns == 0 { + c.Database.MaxIdleConns = 5 + } + if c.Database.ConnMaxLifetime == 0 { + c.Database.ConnMaxLifetime = 300 + } + if c.Database.LogLevel == "" { + c.Database.LogLevel = "warn" + } + + // S3 API默认值 + if c.S3API.AccessKey == "" { + c.S3API.AccessKey = "AKIAIOSFODNN7EXAMPLE" + } + if c.S3API.SecretKey == "" { + c.S3API.SecretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + } + // 默认使用预签名模式(非代理模式) + c.S3API.ProxyMode = false + c.S3API.AuthRequired = false +} + +// ParseMaxSize 解析最大容量字符串为字节 +func (bc *BucketConfig) ParseMaxSize() error { + if bc.MaxSize == "" { + bc.MaxSizeBytes = 0 // 无限制 + return nil + } + + var size int64 + var unit string + _, err := fmt.Sscanf(bc.MaxSize, "%d%s", &size, &unit) + if err != nil { + return fmt.Errorf("invalid size format: %s", bc.MaxSize) + } + + switch unit { + case "B", "b": + bc.MaxSizeBytes = size + case "KB", "kb", "K", "k": + bc.MaxSizeBytes = size * 1024 + case "MB", "mb", "M", "m": + bc.MaxSizeBytes = size * 1024 * 1024 + case "GB", "gb", "G", "g": + bc.MaxSizeBytes = size * 1024 * 1024 * 1024 + case "TB", "tb", "T", "t": + bc.MaxSizeBytes = size * 1024 * 1024 * 1024 * 1024 + default: + return fmt.Errorf("unsupported unit: %s", unit) + } + + return nil +} diff --git a/internal/database/database.go b/internal/database/database.go new file mode 100644 index 0000000..5436e6f --- /dev/null +++ b/internal/database/database.go @@ -0,0 +1,230 @@ +package database + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "github.com/DullJZ/s3-balance/internal/config" + "github.com/DullJZ/s3-balance/internal/storage" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// DB 全局数据库连接 +var DB *gorm.DB + +// Initialize 初始化数据库连接 +func Initialize(cfg *config.DatabaseConfig) error { + var err error + + // 设置日志级别 + logLevel := getLogLevel(cfg.LogLevel) + + // GORM配置 + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logLevel), + NowFunc: func() time.Time { + return time.Now().Local() + }, + QueryFields: true, + } + + // 根据数据库类型创建连接 + switch cfg.Type { + case "sqlite": + DB, err = connectSQLite(cfg.DSN, gormConfig) + case "mysql": + DB, err = connectMySQL(cfg.DSN, gormConfig) + case "postgres", "postgresql": + DB, err = connectPostgreSQL(cfg.DSN, gormConfig) + default: + return fmt.Errorf("unsupported database type: %s", cfg.Type) + } + + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + // 获取底层SQL数据库连接 + sqlDB, err := DB.DB() + if err != nil { + return fmt.Errorf("failed to get sql.DB: %w", err) + } + + // 设置连接池参数 + sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) + sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) + sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second) + + // 测试连接 + if err := sqlDB.Ping(); err != nil { + return fmt.Errorf("failed to ping database: %w", err) + } + + // 自动迁移 + if cfg.AutoMigrate { + if err := AutoMigrate(); err != nil { + return fmt.Errorf("failed to auto migrate: %w", err) + } + } + + log.Printf("Successfully connected to %s database", cfg.Type) + return nil +} + +// connectSQLite 连接SQLite数据库 +func connectSQLite(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) { + // 创建数据目录(如果不存在) + dir := filepath.Dir(dsn) + if dir != "" && dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create database directory: %w", err) + } + } + + // 添加SQLite特定参数 + if dsn != ":memory:" { + dsn = fmt.Sprintf("%s?_journal_mode=WAL&_timeout=5000&_synchronous=NORMAL&_cache_size=10000", dsn) + } + + return gorm.Open(sqlite.Open(dsn), gormConfig) +} + +// connectMySQL 连接MySQL数据库 +func connectMySQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) { + // MySQL DSN示例: user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local + + // 如果DSN中没有指定字符集,添加默认字符集 + if dsn != "" { + dsn = ensureMySQLParams(dsn) + } + + return gorm.Open(mysql.Open(dsn), gormConfig) +} + +// connectPostgreSQL 连接PostgreSQL数据库 +func connectPostgreSQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) { + // PostgreSQL DSN示例: host=localhost user=user password=password dbname=mydb port=5432 sslmode=disable TimeZone=Asia/Shanghai + + return gorm.Open(postgres.Open(dsn), gormConfig) +} + +// ensureMySQLParams 确保MySQL DSN包含必要的参数 +func ensureMySQLParams(dsn string) string { + params := map[string]string{ + "charset": "utf8mb4", + "parseTime": "True", + "loc": "Local", + } + + separator := "?" + if len(dsn) > 0 && dsn[len(dsn)-1] == '?' { + separator = "" + } else if contains(dsn, "?") { + separator = "&" + } + + for key, value := range params { + if !contains(dsn, key+"=") { + dsn = fmt.Sprintf("%s%s%s=%s", dsn, separator, key, value) + separator = "&" + } + } + + return dsn +} + +// contains 检查字符串是否包含子串 +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// getLogLevel 获取GORM日志级别 +func getLogLevel(level string) logger.LogLevel { + switch level { + case "silent": + return logger.Silent + case "error": + return logger.Error + case "warn", "warning": + return logger.Warn + case "info": + return logger.Info + default: + return logger.Warn + } +} + +// AutoMigrate 自动迁移数据库表 +func AutoMigrate() error { + models := []interface{}{ + &storage.Object{}, + &storage.BucketStats{}, + &storage.UploadSession{}, + &storage.AccessLog{}, + } + + for _, model := range models { + if err := DB.AutoMigrate(model); err != nil { + return fmt.Errorf("failed to migrate %T: %w", model, err) + } + } + + log.Println("Database migration completed successfully") + return nil +} + +// Close 关闭数据库连接 +func Close() error { + if DB != nil { + sqlDB, err := DB.DB() + if err != nil { + return err + } + return sqlDB.Close() + } + return nil +} + +// HealthCheck 健康检查 +func HealthCheck() error { + if DB == nil { + return fmt.Errorf("database not initialized") + } + + sqlDB, err := DB.DB() + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + return sqlDB.PingContext(ctx) +} + +// Transaction 执行事务 +func Transaction(fn func(*gorm.DB) error) error { + return DB.Transaction(fn) +} + +// GetDB 获取数据库连接 +func GetDB() *gorm.DB { + return DB +} diff --git a/internal/storage/models.go b/internal/storage/models.go new file mode 100644 index 0000000..cf89c68 --- /dev/null +++ b/internal/storage/models.go @@ -0,0 +1,163 @@ +package storage + +import ( + "database/sql/driver" + "encoding/json" + "time" + + "gorm.io/gorm" +) + +// Object 对象信息模型 +type Object struct { + ID uint `gorm:"primaryKey" json:"id"` + Key string `gorm:"uniqueIndex;size:512;not null" json:"key"` + BucketName string `gorm:"index;size:255;not null" json:"bucket_name"` + Size int64 `gorm:"not null;default:0" json:"size"` + Metadata JSON `gorm:"type:json" json:"metadata,omitempty"` + ContentType string `gorm:"size:128" json:"content_type,omitempty"` + ETag string `gorm:"size:128" json:"etag,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` +} + +// TableName 指定表名 +func (Object) TableName() string { + return "objects" +} + +// BucketStats 存储桶统计信息模型 +type BucketStats struct { + ID uint `gorm:"primaryKey" json:"id"` + BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"` + ObjectCount int64 `gorm:"not null;default:0" json:"object_count"` + TotalSize int64 `gorm:"not null;default:0" json:"total_size"` + LastCheckedAt time.Time `json:"last_checked_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TableName 指定表名 +func (BucketStats) TableName() string { + return "bucket_stats" +} + +// UploadSession 上传会话模型(用于跟踪分片上传) +type UploadSession struct { + ID uint `gorm:"primaryKey" json:"id"` + UploadID string `gorm:"uniqueIndex;size:255;not null" json:"upload_id"` + Key string `gorm:"index;size:512;not null" json:"key"` + BucketName string `gorm:"index;size:255;not null" json:"bucket_name"` + TotalParts int `gorm:"not null;default:0" json:"total_parts"` + CompletedParts int `gorm:"not null;default:0" json:"completed_parts"` + Size int64 `gorm:"not null;default:0" json:"size"` + Status string `gorm:"size:32;not null;default:'pending'" json:"status"` // pending, completed, aborted + ExpiresAt time.Time `gorm:"index" json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` +} + +// TableName 指定表名 +func (UploadSession) TableName() string { + return "upload_sessions" +} + +// AccessLog 访问日志模型 +type AccessLog struct { + ID uint `gorm:"primaryKey" json:"id"` + Action string `gorm:"index;size:32;not null" json:"action"` // upload, download, delete + Key string `gorm:"index;size:512;not null" json:"key"` + BucketName string `gorm:"index;size:255" json:"bucket_name"` + Size int64 `gorm:"default:0" json:"size"` + ClientIP string `gorm:"size:64" json:"client_ip"` + UserAgent string `gorm:"size:512" json:"user_agent"` + Success bool `gorm:"default:true" json:"success"` + ErrorMsg string `gorm:"type:text" json:"error_msg,omitempty"` + ResponseTime int64 `gorm:"default:0" json:"response_time"` // 响应时间(毫秒) + CreatedAt time.Time `gorm:"index" json:"created_at"` +} + +// TableName 指定表名 +func (AccessLog) TableName() string { + return "access_logs" +} + +// JSON 自定义JSON类型,用于存储元数据 +type JSON map[string]interface{} + +// Value 实现driver.Valuer接口 +func (j JSON) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return json.Marshal(j) +} + +// Scan 实现sql.Scanner接口 +func (j *JSON) Scan(value interface{}) error { + if value == nil { + *j = make(map[string]interface{}) + return nil + } + + var data []byte + switch v := value.(type) { + case []byte: + data = v + case string: + data = []byte(v) + default: + data = []byte("{}") + } + + return json.Unmarshal(data, j) +} + +// BeforeCreate GORM钩子 - 创建前 +func (o *Object) BeforeCreate(tx *gorm.DB) error { + if o.Metadata == nil { + o.Metadata = make(JSON) + } + return nil +} + +// BeforeCreate GORM钩子 - 创建前设置过期时间 +func (u *UploadSession) BeforeCreate(tx *gorm.DB) error { + if u.ExpiresAt.IsZero() { + u.ExpiresAt = time.Now().Add(24 * time.Hour) // 默认24小时过期 + } + return nil +} + +// IsExpired 检查上传会话是否过期 +func (u *UploadSession) IsExpired() bool { + return time.Now().After(u.ExpiresAt) +} + +// ObjectFilter 对象查询过滤器 +type ObjectFilter struct { + Key string + BucketName string + Prefix string + MinSize int64 + MaxSize int64 + StartTime time.Time + EndTime time.Time + Limit int + Offset int +} + +// AccessLogFilter 访问日志查询过滤器 +type AccessLogFilter struct { + Action string + Key string + BucketName string + ClientIP string + Success *bool + StartTime time.Time + EndTime time.Time + Limit int + Offset int +} diff --git a/internal/storage/service.go b/internal/storage/service.go new file mode 100644 index 0000000..d6c2b22 --- /dev/null +++ b/internal/storage/service.go @@ -0,0 +1,350 @@ +package storage + +import ( + "fmt" + "time" + + "gorm.io/gorm" +) + +// Service 存储服务(管理对象元数据) +type Service struct { + db *gorm.DB +} + +// NewService 创建新的存储服务 +func NewService(db *gorm.DB) *Service { + return &Service{ + db: db, + } +} + +// RecordObject 记录对象信息 +func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[string]string) error { + obj := &Object{ + Key: key, + BucketName: bucketName, + Size: size, + } + + if len(metadata) > 0 { + obj.Metadata = make(JSON) + for k, v := range metadata { + obj.Metadata[k] = v + } + } else { + obj.Metadata = make(JSON) + } + + // 使用 Upsert(更新或插入) + result := s.db.Where("key = ?", key).FirstOrCreate(&obj) + if result.Error != nil { + return fmt.Errorf("failed to record object: %w", result.Error) + } + + if result.RowsAffected == 0 { + // 对象已存在,更新它 + updates := map[string]interface{}{ + "bucket_name": bucketName, + "size": size, + "metadata": obj.Metadata, + "updated_at": time.Now(), + } + if err := s.db.Model(&Object{}).Where("key = ?", key).Updates(updates).Error; err != nil { + return fmt.Errorf("failed to update object: %w", err) + } + } + + // 更新存储桶统计 + s.updateBucketStats(bucketName) + + return nil +} + +// FindObjectBucket 查找对象所在的存储桶 +func (s *Service) FindObjectBucket(key string) (string, error) { + var obj Object + if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return "", fmt.Errorf("object not found: %s", key) + } + return "", fmt.Errorf("failed to find object: %w", err) + } + return obj.BucketName, nil +} + +// GetObjectInfo 获取对象信息 +func (s *Service) GetObjectInfo(key string) (*Object, error) { + var obj Object + if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, fmt.Errorf("object not found: %s", key) + } + return nil, fmt.Errorf("failed to get object info: %w", err) + } + return &obj, nil +} + +// DeleteObject 删除对象记录(软删除) +func (s *Service) DeleteObject(key string) error { + var obj Object + if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return fmt.Errorf("object not found: %s", key) + } + return fmt.Errorf("failed to find object: %w", err) + } + + bucketName := obj.BucketName + + // 软删除 + if err := s.db.Delete(&obj).Error; err != nil { + return fmt.Errorf("failed to delete object: %w", err) + } + + // 更新存储桶统计 + s.updateBucketStats(bucketName) + + return nil +} + +// ListObjects 列出对象(支持S3兼容的参数) +func (s *Service) ListObjects(bucketName, prefix, marker string, maxKeys int) ([]*Object, error) { + var objects []*Object + query := s.db.Model(&Object{}) + + // 按bucket过滤 + if bucketName != "" { + query = query.Where("bucket_name = ?", bucketName) + } + + // 前缀过滤 + if prefix != "" { + query = query.Where("key LIKE ?", prefix+"%") + } + + // Marker分页 + if marker != "" { + query = query.Where("key > ?", marker) + } + + // 限制返回数量 + if maxKeys > 0 { + query = query.Limit(maxKeys) + } + + // 按key字母顺序排序(S3标准) + if err := query.Order("key ASC").Find(&objects).Error; err != nil { + return nil, fmt.Errorf("failed to list objects: %w", err) + } + + return objects, nil +} + +// GetBucketObjects 获取特定存储桶的所有对象 +func (s *Service) GetBucketObjects(bucketName string) ([]*Object, error) { + var objects []*Object + if err := s.db.Where("bucket_name = ?", bucketName).Find(&objects).Error; err != nil { + return nil, fmt.Errorf("failed to get bucket objects: %w", err) + } + return objects, nil +} + +// GetTotalSize 获取所有对象的总大小 +func (s *Service) GetTotalSize() (int64, error) { + var total int64 + if err := s.db.Model(&Object{}).Select("COALESCE(SUM(size), 0)").Scan(&total).Error; err != nil { + return 0, fmt.Errorf("failed to get total size: %w", err) + } + return total, nil +} + +// GetBucketSize 获取特定存储桶的总大小 +func (s *Service) GetBucketSize(bucketName string) (int64, error) { + var total int64 + if err := s.db.Model(&Object{}). + Where("bucket_name = ?", bucketName). + Select("COALESCE(SUM(size), 0)"). + Scan(&total).Error; err != nil { + return 0, fmt.Errorf("failed to get bucket size: %w", err) + } + return total, nil +} + +// GetObjectCount 获取对象总数 +func (s *Service) GetObjectCount() (int64, error) { + var count int64 + if err := s.db.Model(&Object{}).Count(&count).Error; err != nil { + return 0, fmt.Errorf("failed to get object count: %w", err) + } + return count, nil +} + +// GetBucketObjectCount 获取特定存储桶的对象数 +func (s *Service) GetBucketObjectCount(bucketName string) (int64, error) { + var count int64 + if err := s.db.Model(&Object{}). + Where("bucket_name = ?", bucketName). + Count(&count).Error; err != nil { + return 0, fmt.Errorf("failed to get bucket object count: %w", err) + } + return count, nil +} + +// updateBucketStats 更新存储桶统计信息 +func (s *Service) updateBucketStats(bucketName string) error { + var stats BucketStats + + // 获取或创建统计记录 + result := s.db.Where("bucket_name = ?", bucketName).FirstOrCreate(&stats, BucketStats{ + BucketName: bucketName, + }) + if result.Error != nil { + return fmt.Errorf("failed to get bucket stats: %w", result.Error) + } + + // 计算新的统计数据 + var count int64 + var totalSize int64 + + s.db.Model(&Object{}). + Where("bucket_name = ?", bucketName). + Count(&count) + + s.db.Model(&Object{}). + Where("bucket_name = ?", bucketName). + Select("COALESCE(SUM(size), 0)"). + Scan(&totalSize) + + // 更新统计数据 + updates := map[string]interface{}{ + "object_count": count, + "total_size": totalSize, + "last_checked_at": time.Now(), + } + + if err := s.db.Model(&stats).Updates(updates).Error; err != nil { + return fmt.Errorf("failed to update bucket stats: %w", err) + } + + return nil +} + +// RecordUploadSession 记录上传会话 +func (s *Service) RecordUploadSession(uploadID, key, bucketName string, totalParts int, size int64) error { + session := &UploadSession{ + UploadID: uploadID, + Key: key, + BucketName: bucketName, + TotalParts: totalParts, + Size: size, + Status: "pending", + } + + if err := s.db.Create(session).Error; err != nil { + return fmt.Errorf("failed to record upload session: %w", err) + } + + return nil +} + +// GetUploadSession 获取上传会话 +func (s *Service) GetUploadSession(uploadID string) (*UploadSession, error) { + var session UploadSession + if err := s.db.Where("upload_id = ?", uploadID).First(&session).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, fmt.Errorf("upload session not found: %s", uploadID) + } + return nil, fmt.Errorf("failed to get upload session: %w", err) + } + return &session, nil +} + +// UpdateUploadSession 更新上传会话 +func (s *Service) UpdateUploadSession(uploadID string, completedParts int, status string) error { + updates := map[string]interface{}{ + "completed_parts": completedParts, + "status": status, + "updated_at": time.Now(), + } + + if err := s.db.Model(&UploadSession{}). + Where("upload_id = ?", uploadID). + Updates(updates).Error; err != nil { + return fmt.Errorf("failed to update upload session: %w", err) + } + + return nil +} + +// CleanExpiredSessions 清理过期的上传会话 +func (s *Service) CleanExpiredSessions() error { + if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending"). + Delete(&UploadSession{}).Error; err != nil { + return fmt.Errorf("failed to clean expired sessions: %w", err) + } + return nil +} + +// RecordAccessLog 记录访问日志 +func (s *Service) RecordAccessLog(action, key, bucketName, clientIP, userAgent string, size int64, success bool, errorMsg string, responseTime int64) error { + log := &AccessLog{ + Action: action, + Key: key, + BucketName: bucketName, + ClientIP: clientIP, + UserAgent: userAgent, + Size: size, + Success: success, + ErrorMsg: errorMsg, + ResponseTime: responseTime, + } + + if err := s.db.Create(log).Error; err != nil { + return fmt.Errorf("failed to record access log: %w", err) + } + + return nil +} + +// GetAccessLogs 获取访问日志 +func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) { + query := s.db.Model(&AccessLog{}) + + if filter != nil { + if filter.Action != "" { + query = query.Where("action = ?", filter.Action) + } + if filter.Key != "" { + query = query.Where("key = ?", filter.Key) + } + if filter.BucketName != "" { + query = query.Where("bucket_name = ?", filter.BucketName) + } + if filter.ClientIP != "" { + query = query.Where("client_ip = ?", filter.ClientIP) + } + if filter.Success != nil { + query = query.Where("success = ?", *filter.Success) + } + if !filter.StartTime.IsZero() { + query = query.Where("created_at >= ?", filter.StartTime) + } + if !filter.EndTime.IsZero() { + query = query.Where("created_at <= ?", filter.EndTime) + } + if filter.Limit > 0 { + query = query.Limit(filter.Limit) + } + if filter.Offset > 0 { + query = query.Offset(filter.Offset) + } + } + + var logs []*AccessLog + if err := query.Order("created_at DESC").Find(&logs).Error; err != nil { + return nil, fmt.Errorf("failed to get access logs: %w", err) + } + + return logs, nil +} diff --git a/pkg/presigner/multipart.go b/pkg/presigner/multipart.go new file mode 100644 index 0000000..e87e800 --- /dev/null +++ b/pkg/presigner/multipart.go @@ -0,0 +1,85 @@ +package presigner + +import ( + "context" + "fmt" + + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +// CompletedPart 已完成的分片信息 +type CompletedPart struct { + PartNumber int32 `json:"part_number"` + ETag string `json:"etag"` +} + +// CompleteMultipartUpload 完成分片上传 +func CompleteMultipartUpload(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string, parts []CompletedPart) error { + // 转换为AWS SDK格式 + var completedParts []types.CompletedPart + for _, part := range parts { + completedParts = append(completedParts, types.CompletedPart{ + PartNumber: aws.Int32(part.PartNumber), + ETag: aws.String(part.ETag), + }) + } + + // 完成分片上传 + _, err := bucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + MultipartUpload: &types.CompletedMultipartUpload{ + Parts: completedParts, + }, + }) + if err != nil { + return fmt.Errorf("failed to complete multipart upload: %w", err) + } + + return nil +} + +// AbortMultipartUpload 中止分片上传 +func AbortMultipartUpload(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string) error { + _, err := bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + }) + if err != nil { + return fmt.Errorf("failed to abort multipart upload: %w", err) + } + + return nil +} + +// ListParts 列出已上传的分片 +func ListParts(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string) ([]types.Part, error) { + var allParts []types.Part + var nextPartNumberMarker *string + + for { + output, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + PartNumberMarker: nextPartNumberMarker, + }) + if err != nil { + return nil, fmt.Errorf("failed to list parts: %w", err) + } + + allParts = append(allParts, output.Parts...) + + if output.IsTruncated == nil || !*output.IsTruncated { + break + } + nextPartNumberMarker = output.NextPartNumberMarker + } + + return allParts, nil +} diff --git a/pkg/presigner/presigner.go b/pkg/presigner/presigner.go new file mode 100644 index 0000000..0819f24 --- /dev/null +++ b/pkg/presigner/presigner.go @@ -0,0 +1,221 @@ +package presigner + +import ( + "context" + "fmt" + "time" + + "github.com/DullJZ/s3-balance/internal/bucket" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// Presigner 预签名URL生成器 +type Presigner struct { + uploadExpiry time.Duration + downloadExpiry time.Duration +} + +// NewPresigner 创建新的预签名URL生成器 +func NewPresigner(uploadExpiry, downloadExpiry time.Duration) *Presigner { + // 设置默认值 + if uploadExpiry == 0 { + uploadExpiry = 15 * time.Minute + } + if downloadExpiry == 0 { + downloadExpiry = 60 * time.Minute + } + + return &Presigner{ + uploadExpiry: uploadExpiry, + downloadExpiry: downloadExpiry, + } +} + +// UploadURL 生成上传预签名URL +type UploadURL struct { + URL string `json:"url"` + Method string `json:"method"` + Headers map[string]string `json:"headers,omitempty"` + Expiry time.Time `json:"expiry"` + BucketName string `json:"bucket_name"` + Key string `json:"key"` +} + +// GenerateUploadURL 生成上传预签名URL +func (p *Presigner) GenerateUploadURL(ctx context.Context, bucket *bucket.BucketInfo, key string, contentType string, metadata map[string]string) (*UploadURL, error) { + presignClient := s3.NewPresignClient(bucket.Client) + + // 构建PutObject请求 + putObjectInput := &s3.PutObjectInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + } + + // 设置Content-Type + if contentType != "" { + putObjectInput.ContentType = aws.String(contentType) + } + + // 设置元数据 + if len(metadata) > 0 { + putObjectInput.Metadata = metadata + } + + // 生成预签名URL + presignRequest, err := presignClient.PresignPutObject(ctx, putObjectInput, func(opts *s3.PresignOptions) { + opts.Expires = p.uploadExpiry + }) + if err != nil { + return nil, fmt.Errorf("failed to generate upload presigned URL: %w", err) + } + + // 转换Headers为map[string]string + headers := make(map[string]string) + for k, v := range presignRequest.SignedHeader { + if len(v) > 0 { + headers[k] = v[0] + } + } + + return &UploadURL{ + URL: presignRequest.URL, + Method: presignRequest.Method, + Headers: headers, + Expiry: time.Now().Add(p.uploadExpiry), + BucketName: bucket.Config.Name, + Key: key, + }, nil +} + +// DownloadURL 生成下载预签名URL +type DownloadURL struct { + URL string `json:"url"` + Method string `json:"method"` + Expiry time.Time `json:"expiry"` + BucketName string `json:"bucket_name"` + Key string `json:"key"` +} + +// GenerateDownloadURL 生成下载预签名URL +func (p *Presigner) GenerateDownloadURL(ctx context.Context, bucket *bucket.BucketInfo, key string) (*DownloadURL, error) { + presignClient := s3.NewPresignClient(bucket.Client) + + // 构建GetObject请求 + getObjectInput := &s3.GetObjectInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + } + + // 生成预签名URL + presignRequest, err := presignClient.PresignGetObject(ctx, getObjectInput, func(opts *s3.PresignOptions) { + opts.Expires = p.downloadExpiry + }) + if err != nil { + return nil, fmt.Errorf("failed to generate download presigned URL: %w", err) + } + + return &DownloadURL{ + URL: presignRequest.URL, + Method: presignRequest.Method, + Expiry: time.Now().Add(p.downloadExpiry), + BucketName: bucket.Config.Name, + Key: key, + }, nil +} + +// DeleteURL 生成删除预签名URL +type DeleteURL struct { + URL string `json:"url"` + Method string `json:"method"` + Expiry time.Time `json:"expiry"` + BucketName string `json:"bucket_name"` + Key string `json:"key"` +} + +// GenerateDeleteURL 生成删除预签名URL +func (p *Presigner) GenerateDeleteURL(ctx context.Context, bucket *bucket.BucketInfo, key string) (*DeleteURL, error) { + presignClient := s3.NewPresignClient(bucket.Client) + + // 构建DeleteObject请求 + deleteObjectInput := &s3.DeleteObjectInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + } + + // 生成预签名URL + presignRequest, err := presignClient.PresignDeleteObject(ctx, deleteObjectInput, func(opts *s3.PresignOptions) { + opts.Expires = 5 * time.Minute // 删除操作的URL有效期较短 + }) + if err != nil { + return nil, fmt.Errorf("failed to generate delete presigned URL: %w", err) + } + + return &DeleteURL{ + URL: presignRequest.URL, + Method: presignRequest.Method, + Expiry: time.Now().Add(5 * time.Minute), + BucketName: bucket.Config.Name, + Key: key, + }, nil +} + +// MultipartUploadURLs 分片上传预签名URLs +type MultipartUploadURLs struct { + UploadID string `json:"upload_id"` + PartURLs map[int]string `json:"part_urls"` + BucketName string `json:"bucket_name"` + Key string `json:"key"` + Expiry time.Time `json:"expiry"` +} + +// GenerateMultipartUploadURLs 生成分片上传预签名URLs +func (p *Presigner) GenerateMultipartUploadURLs(ctx context.Context, bucket *bucket.BucketInfo, key string, partCount int) (*MultipartUploadURLs, error) { + // 初始化分片上传 + createResp, err := bucket.Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + }) + if err != nil { + return nil, fmt.Errorf("failed to create multipart upload: %w", err) + } + + presignClient := s3.NewPresignClient(bucket.Client) + partURLs := make(map[int]string) + + // 为每个分片生成预签名URL + for i := 1; i <= partCount; i++ { + uploadPartInput := &s3.UploadPartInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: createResp.UploadId, + PartNumber: aws.Int32(int32(i)), + } + + presignRequest, err := presignClient.PresignUploadPart(ctx, uploadPartInput, func(opts *s3.PresignOptions) { + opts.Expires = p.uploadExpiry + }) + if err != nil { + // 如果失败,中止分片上传 + bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: createResp.UploadId, + }) + return nil, fmt.Errorf("failed to generate part %d presigned URL: %w", i, err) + } + + partURLs[i] = presignRequest.URL + } + + // 注意:CompleteMultipartUpload 和 AbortMultipartUpload 需要在客户端直接调用 + // 因为它们需要提供额外的参数(如Parts列表),不适合预签名 + + return &MultipartUploadURLs{ + UploadID: *createResp.UploadId, + PartURLs: partURLs, + BucketName: bucket.Config.Name, + Key: key, + Expiry: time.Now().Add(p.uploadExpiry), + }, nil +} diff --git a/test_s3_compatibility.py b/test_s3_compatibility.py new file mode 100644 index 0000000..0c8044b --- /dev/null +++ b/test_s3_compatibility.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +""" +S3 Balance - S3兼容性测试脚本 +使用boto3 AWS SDK测试S3 Balance的S3兼容性 +""" + +import os +import sys +import time +import hashlib +import tempfile +from datetime import datetime + +try: + import boto3 + from botocore.client import Config +except ImportError: + print("请先安装boto3: pip install boto3") + sys.exit(1) + +# S3 Balance服务配置 +S3_BALANCE_ENDPOINT = "http://localhost:8080" +ACCESS_KEY = "AKIAIOSFODNN7EXAMPLE" +SECRET_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + +# 测试配置 +TEST_BUCKET = "test-bucket-1" +TEST_KEY_PREFIX = f"test-{int(time.time())}" + +def create_s3_client(): + """创建S3客户端""" + return boto3.client( + 's3', + endpoint_url=S3_BALANCE_ENDPOINT, + aws_access_key_id=ACCESS_KEY, + aws_secret_access_key=SECRET_KEY, + config=Config( + signature_version='s3v4', + s3={'addressing_style': 'path'} + ), + region_name='us-east-1' + ) + +def test_list_buckets(s3_client): + """测试列出存储桶""" + print("\n1. 测试列出存储桶 (ListBuckets)...") + try: + response = s3_client.list_buckets() + buckets = response.get('Buckets', []) + print(f" ✓ 找到 {len(buckets)} 个存储桶") + for bucket in buckets: + print(f" - {bucket['Name']} (创建时间: {bucket['CreationDate']})") + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def test_upload_object(s3_client): + """测试上传对象""" + print(f"\n2. 测试上传对象 (PutObject)...") + + # 创建测试文件 + test_data = b"Hello, S3 Balance! This is a test file." + test_key = f"{TEST_KEY_PREFIX}/test-upload.txt" + + try: + # 上传对象 + response = s3_client.put_object( + Bucket=TEST_BUCKET, + Key=test_key, + Body=test_data, + ContentType='text/plain', + Metadata={'test': 'true', 'timestamp': str(int(time.time()))} + ) + + etag = response.get('ETag', '').strip('"') + print(f" ✓ 成功上传对象: {test_key}") + print(f" ETag: {etag}") + return test_key + except Exception as e: + print(f" ✗ 失败: {e}") + return None + +def test_list_objects(s3_client): + """测试列出对象""" + print(f"\n3. 测试列出对象 (ListObjects)...") + + try: + response = s3_client.list_objects_v2( + Bucket=TEST_BUCKET, + Prefix=TEST_KEY_PREFIX, + MaxKeys=10 + ) + + objects = response.get('Contents', []) + print(f" ✓ 找到 {len(objects)} 个对象") + for obj in objects: + print(f" - {obj['Key']} (大小: {obj['Size']} bytes, 修改时间: {obj['LastModified']})") + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def test_download_object(s3_client, key): + """测试下载对象""" + print(f"\n4. 测试下载对象 (GetObject)...") + + if not key: + print(" ⚠ 跳过: 没有可下载的对象") + return False + + try: + response = s3_client.get_object( + Bucket=TEST_BUCKET, + Key=key + ) + + data = response['Body'].read() + content_type = response.get('ContentType', '') + content_length = response.get('ContentLength', 0) + + print(f" ✓ 成功下载对象: {key}") + print(f" 内容类型: {content_type}") + print(f" 内容长度: {content_length} bytes") + print(f" 内容预览: {data[:50].decode('utf-8', errors='ignore')}...") + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def test_head_object(s3_client, key): + """测试获取对象元数据""" + print(f"\n5. 测试获取对象元数据 (HeadObject)...") + + if not key: + print(" ⚠ 跳过: 没有可查询的对象") + return False + + try: + response = s3_client.head_object( + Bucket=TEST_BUCKET, + Key=key + ) + + print(f" ✓ 成功获取对象元数据: {key}") + print(f" 内容长度: {response.get('ContentLength', 0)} bytes") + print(f" 内容类型: {response.get('ContentType', '')}") + print(f" ETag: {response.get('ETag', '').strip('\"')}") + print(f" 最后修改: {response.get('LastModified', '')}") + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def test_multipart_upload(s3_client): + """测试分片上传(大文件)""" + print(f"\n6. 测试分片上传 (Multipart Upload)...") + + # 创建一个5MB的测试文件 + test_key = f"{TEST_KEY_PREFIX}/test-multipart.bin" + part_size = 5 * 1024 * 1024 # 5MB per part + total_size = 10 * 1024 * 1024 # 10MB total + + try: + # 初始化分片上传 + response = s3_client.create_multipart_upload( + Bucket=TEST_BUCKET, + Key=test_key, + ContentType='application/octet-stream' + ) + upload_id = response['UploadId'] + print(f" ✓ 初始化分片上传,UploadId: {upload_id}") + + # 上传分片 + parts = [] + for i in range(2): # 上传2个5MB的分片 + part_number = i + 1 + part_data = os.urandom(part_size) # 生成随机数据 + + part_response = s3_client.upload_part( + Bucket=TEST_BUCKET, + Key=test_key, + PartNumber=part_number, + UploadId=upload_id, + Body=part_data + ) + + parts.append({ + 'ETag': part_response['ETag'], + 'PartNumber': part_number + }) + print(f" ✓ 上传分片 {part_number}/2 完成") + + # 完成分片上传 + s3_client.complete_multipart_upload( + Bucket=TEST_BUCKET, + Key=test_key, + UploadId=upload_id, + MultipartUpload={'Parts': parts} + ) + + print(f" ✓ 分片上传完成: {test_key}") + return test_key + except Exception as e: + print(f" ✗ 失败: {e}") + return None + +def test_delete_object(s3_client, key): + """测试删除对象""" + print(f"\n7. 测试删除对象 (DeleteObject)...") + + if not key: + print(" ⚠ 跳过: 没有可删除的对象") + return False + + try: + s3_client.delete_object( + Bucket=TEST_BUCKET, + Key=key + ) + + print(f" ✓ 成功删除对象: {key}") + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def test_presigned_url(s3_client): + """测试预签名URL""" + print(f"\n8. 测试预签名URL...") + + test_key = f"{TEST_KEY_PREFIX}/test-presigned.txt" + + try: + # 生成上传预签名URL + upload_url = s3_client.generate_presigned_url( + 'put_object', + Params={'Bucket': TEST_BUCKET, 'Key': test_key}, + ExpiresIn=3600 + ) + print(f" ✓ 生成上传预签名URL") + print(f" URL: {upload_url[:80]}...") + + # 生成下载预签名URL + download_url = s3_client.generate_presigned_url( + 'get_object', + Params={'Bucket': TEST_BUCKET, 'Key': test_key}, + ExpiresIn=3600 + ) + print(f" ✓ 生成下载预签名URL") + print(f" URL: {download_url[:80]}...") + + return True + except Exception as e: + print(f" ✗ 失败: {e}") + return False + +def cleanup(s3_client): + """清理测试数据""" + print(f"\n清理测试数据...") + + try: + # 列出所有测试对象 + response = s3_client.list_objects_v2( + Bucket=TEST_BUCKET, + Prefix=TEST_KEY_PREFIX + ) + + objects = response.get('Contents', []) + if objects: + # 删除所有测试对象 + delete_objects = [{'Key': obj['Key']} for obj in objects] + s3_client.delete_objects( + Bucket=TEST_BUCKET, + Delete={'Objects': delete_objects} + ) + print(f" ✓ 删除了 {len(objects)} 个测试对象") + else: + print(" ✓ 没有需要清理的对象") + + return True + except Exception as e: + print(f" ✗ 清理失败: {e}") + return False + +def main(): + """主测试函数""" + print("=" * 60) + print("S3 Balance - S3兼容性测试") + print("=" * 60) + print(f"端点: {S3_BALANCE_ENDPOINT}") + print(f"测试桶: {TEST_BUCKET}") + print(f"测试前缀: {TEST_KEY_PREFIX}") + + # 创建S3客户端 + s3_client = create_s3_client() + + # 执行测试 + results = [] + + # 基础测试 + results.append(("ListBuckets", test_list_buckets(s3_client))) + + # 对象操作测试 + uploaded_key = test_upload_object(s3_client) + results.append(("PutObject", uploaded_key is not None)) + + results.append(("ListObjects", test_list_objects(s3_client))) + results.append(("GetObject", test_download_object(s3_client, uploaded_key))) + results.append(("HeadObject", test_head_object(s3_client, uploaded_key))) + + # 高级功能测试 + # multipart_key = test_multipart_upload(s3_client) + # results.append(("MultipartUpload", multipart_key is not None)) + + results.append(("PresignedURL", test_presigned_url(s3_client))) + + # 删除测试 + results.append(("DeleteObject", test_delete_object(s3_client, uploaded_key))) + + # 清理 + cleanup(s3_client) + + # 打印测试结果摘要 + print("\n" + "=" * 60) + print("测试结果摘要") + print("=" * 60) + + passed = sum(1 for _, result in results if result) + total = len(results) + + for test_name, result in results: + status = "✓ 通过" if result else "✗ 失败" + print(f"{test_name:20} {status}") + + print("-" * 60) + print(f"总计: {passed}/{total} 测试通过") + + if passed == total: + print("\n🎉 所有测试通过!S3 Balance S3兼容性良好。") + else: + print(f"\n⚠️ 有 {total - passed} 个测试失败,请检查服务配置。") + + return 0 if passed == total else 1 + +if __name__ == "__main__": + sys.exit(main())