This commit is contained in:
DullJZ
2025-08-22 21:15:56 +08:00
commit 37b6adb6de
16 changed files with 3838 additions and 0 deletions

48
.gitignore vendored Normal file
View File

@@ -0,0 +1,48 @@
# Binaries
*.exe
*.dll
*.so
*.dylib
# Test binary
*.test
# Output of the go coverage tool
*.out
# Dependency directories
vendor/
# Configuration files (keep examples)
config/*.yaml
!config/*.example.yaml
# Database files
data/
*.db
*.db-shm
*.db-wal
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Logs
*.log
logs/
# Environment variables
.env
.env.local

182
cmd/s3-balance/main.go Normal file
View File

@@ -0,0 +1,182 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/DullJZ/s3-balance/internal/api"
"github.com/DullJZ/s3-balance/internal/balancer"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/config"
"github.com/DullJZ/s3-balance/internal/database"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/DullJZ/s3-balance/pkg/presigner"
"github.com/gorilla/mux"
)
func main() {
// 解析命令行参数
var configFile string
flag.StringVar(&configFile, "config", "config/config.yaml", "Path to configuration file")
flag.Parse()
// 加载配置
cfg, err := config.Load(configFile)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
// 初始化数据库
if err := database.Initialize(&cfg.Database); err != nil {
log.Fatalf("Failed to initialize database: %v", err)
}
defer database.Close()
// 创建存储桶管理器
bucketManager, err := bucket.NewManager(cfg)
if err != nil {
log.Fatalf("Failed to create bucket manager: %v", err)
}
// 启动存储桶管理器(健康检查和统计更新)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
bucketManager.Start(ctx)
// 创建负载均衡器
lb, err := balancer.NewBalancer(bucketManager, &cfg.Balancer)
if err != nil {
log.Fatalf("Failed to create balancer: %v", err)
}
// 创建预签名URL生成器
signer := presigner.NewPresigner(
15*time.Minute, // 上传URL有效期
60*time.Minute, // 下载URL有效期
)
// 创建存储服务
storageService := storage.NewService(database.GetDB())
// 创建S3兼容API处理器
s3Handler := api.NewS3Handler(
bucketManager,
lb,
signer,
storageService,
cfg.S3API.AccessKey,
cfg.S3API.SecretKey,
)
// 设置路由
router := mux.NewRouter()
// 运行在S3兼容模式
log.Println("Running in S3-compatible mode")
s3Handler.RegisterS3Routes(router)
// 添加CORS中间件
router.Use(corsMiddleware)
// 添加日志中间件
router.Use(loggingMiddleware)
// 创建HTTP服务器
srv := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
}
// 启动服务器
go func() {
log.Printf("Starting S3 Balance Service on %s", srv.Addr)
log.Printf("Load balancing strategy: %s", cfg.Balancer.Strategy)
log.Printf("Managed buckets: %d", len(cfg.Buckets))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start server: %v", err)
}
}()
// 等待中断信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
// 优雅关闭
log.Println("Shutting down server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Server shutdown error: %v", err)
}
// 停止存储桶管理器
bucketManager.Stop()
cancel()
log.Println("Server stopped")
}
// CORS中间件
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// 日志中间件
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 包装ResponseWriter以捕获状态码
wrapped := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
log.Printf(
"[%s] %s %s %d %v",
r.RemoteAddr,
r.Method,
r.RequestURI,
wrapped.statusCode,
time.Since(start),
)
})
}
// responseWriter 包装器用于捕获状态码
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}

129
config/config.example.yaml Normal file
View File

@@ -0,0 +1,129 @@
# S3 Balance Service Configuration Example
# 服务器配置
server:
host: "0.0.0.0"
port: 8080
read_timeout: 30s
write_timeout: 30s
idle_timeout: 60s
# 数据库配置
database:
# 数据库类型: sqlite, mysql, postgres
type: "sqlite"
# 数据源名称 (DSN)
# SQLite 示例:
dsn: "data/s3-balance.db"
# MySQL 示例:
# dsn: "user:password@tcp(localhost:3306)/s3balance?charset=utf8mb4&parseTime=True&loc=Local"
# PostgreSQL 示例:
# dsn: "host=localhost user=postgres password=password dbname=s3balance port=5432 sslmode=disable TimeZone=Asia/Shanghai"
# 连接池配置
max_open_conns: 25 # 最大打开连接数
max_idle_conns: 5 # 最大空闲连接数
conn_max_lifetime: 300 # 连接最大生命周期(秒)
# 日志级别: silent, error, warn, info
log_level: "warn"
# 是否自动迁移数据库表
auto_migrate: true
# S3存储桶配置
buckets:
# 第一个存储桶 - AWS S3
- name: "my-bucket-1"
endpoint: "" # 留空使用默认AWS端点
region: "us-east-1"
access_key_id: "YOUR_AWS_ACCESS_KEY_ID"
secret_access_key: "YOUR_AWS_SECRET_ACCESS_KEY"
max_size: "10GB" # 最大容量限制
weight: 10 # 权重(用于加权负载均衡)
enabled: true
use_ssl: true
path_style: false # AWS S3使用虚拟主机风格
# 第二个存储桶 - MinIO
- name: "my-bucket-2"
endpoint: "http://localhost:9000"
region: "us-east-1"
access_key_id: "minioadmin"
secret_access_key: "minioadmin"
max_size: "5GB"
weight: 5
enabled: true
use_ssl: false
path_style: true # MinIO通常使用路径风格
# 第三个存储桶 - 阿里云OSS兼容S3
- name: "my-bucket-3"
endpoint: "https://oss-cn-hangzhou.aliyuncs.com"
region: "cn-hangzhou"
access_key_id: "YOUR_ALIYUN_ACCESS_KEY_ID"
secret_access_key: "YOUR_ALIYUN_SECRET_ACCESS_KEY"
max_size: "20GB"
weight: 15
enabled: true
use_ssl: true
path_style: false
# 备用存储桶(可以禁用)
- name: "backup-bucket"
endpoint: "http://backup-s3.example.com"
region: "us-west-2"
access_key_id: "BACKUP_ACCESS_KEY"
secret_access_key: "BACKUP_SECRET_KEY"
max_size: "100GB"
weight: 1
enabled: false # 当前禁用
use_ssl: false
path_style: true
# 负载均衡配置
balancer:
# 负载均衡策略,可选值:
# - "round-robin": 轮询
# - "least-space": 选择剩余空间最多的存储桶
# - "weighted": 基于权重的随机选择
# - "consistent-hash": 一致性哈希相同的key总是选择相同的存储桶
strategy: "least-space"
# 健康检查周期
health_check_period: 30s
# 统计信息更新周期
update_stats_period: 60s
# 重试配置
retry_attempts: 3
retry_delay: 1s
# 监控指标配置
metrics:
enabled: true
path: "/metrics"
port: 9090
# S3兼容API配置
s3api:
# 客户端连接用的Access Key
access_key: "AKIAIOSFODNN7EXAMPLE"
# 客户端连接用的Secret Key
secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
# 是否使用虚拟主机模式
virtual_host: false
# 工作模式:
# false (默认):预签名重定向模式,客户端直接与后端存储交互
# true代理模式数据通过S3 Balance服务器传输
proxy_mode: false
# 是否需要认证开发环境可设为false
auth_required: false

45
go.mod Normal file
View File

@@ -0,0 +1,45 @@
module github.com/DullJZ/s3-balance
go 1.24.5
require (
github.com/aws/aws-sdk-go-v2 v1.38.0
github.com/aws/aws-sdk-go-v2/config v1.31.1
github.com/aws/aws-sdk-go-v2/credentials v1.18.5
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0
github.com/gorilla/mux v1.8.1
gopkg.in/yaml.v3 v3.0.1
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.28.1 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.37.1 // indirect
github.com/aws/smithy-go v1.22.5 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/text v0.21.0 // indirect
gorm.io/driver/mysql v1.6.0 // indirect
gorm.io/driver/postgres v1.6.0 // indirect
gorm.io/driver/sqlite v1.6.0 // indirect
gorm.io/gorm v1.30.1 // indirect
)

80
go.sum Normal file
View File

@@ -0,0 +1,80 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU=
github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
github.com/aws/aws-sdk-go-v2/config v1.31.1 h1:PSQn4ObaQLaHl6qjs+XYH2pkxyHzZlk1GgQDrKlRJ7I=
github.com/aws/aws-sdk-go-v2/config v1.31.1/go.mod h1:3UA8Gj+2nzpV8WBUF0b19onBfz0YMXDQyGEW0Ru1ntI=
github.com/aws/aws-sdk-go-v2/credentials v1.18.5 h1:DATc1xnpHUV8VgvtnVQul+zuCwK6vz7gtkbKEUZcuNI=
github.com/aws/aws-sdk-go-v2/credentials v1.18.5/go.mod h1:y7aigZzjm1jUZuCgOrlBng+VJrKkknY2Cl0JWxG7vHU=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3 h1:ZV2XK2L3HBq9sCKQiQ/MdhZJppH/rH0vddEAamsHUIs=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.3/go.mod h1:b9F9tk2HdHpbf3xbN7rUZcfmJI26N6NcJu/8OsBFI/0=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3 h1:3ZKmesYBaFX33czDl6mbrcHb6jeheg6LqjJhQdefhsY=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.3/go.mod h1:7ryVb78GLCnjq7cw45N6oUb9REl7/vNUwjvIqC5UgdY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3 h1:SE/e52dq9a05RuxzLcjT+S5ZpQobj3ie3UTaSf2NnZc=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.3/go.mod h1:zkpvBTsR020VVr8TOrwK2TrUW9pOir28sH5ECHpnAfo=
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0 h1:egoDf+Geuuntmw79Mz6mk9gGmELCPzg5PFEABOHB+6Y=
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.0/go.mod h1:t9MDi29H+HDbkolTSQtbI0HP9DemAWQzUjmWC7LGMnE=
github.com/aws/aws-sdk-go-v2/service/sso v1.28.1 h1:YfsU8hHGvVT+c6Q8MUs8haDbFQajAImrB7yZ9XnPcBY=
github.com/aws/aws-sdk-go-v2/service/sso v1.28.1/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1 h1:b4REsk5C0hooowAPmV8fS2haHb+HCyb5FKSKOZRBBfU=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.1/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs=
github.com/aws/aws-sdk-go-v2/service/sts v1.37.1 h1:ssCHKyNJqTnqRH4Vlf+jI0brtGQYBvzWwnATsOMk1mk=
github.com/aws/aws-sdk-go-v2/service/sts v1.37.1/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8=
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4=
gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=

392
internal/api/handler.go Normal file
View File

@@ -0,0 +1,392 @@
package api
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
"time"
"github.com/DullJZ/s3-balance/internal/balancer"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/DullJZ/s3-balance/pkg/presigner"
"github.com/gorilla/mux"
)
// Handler API处理器
type Handler struct {
bucketManager *bucket.Manager
balancer *balancer.Balancer
presigner *presigner.Presigner
storage *storage.Service
}
// NewHandler 创建新的API处理器
func NewHandler(
bucketManager *bucket.Manager,
balancer *balancer.Balancer,
presigner *presigner.Presigner,
storage *storage.Service,
) *Handler {
return &Handler{
bucketManager: bucketManager,
balancer: balancer,
presigner: presigner,
storage: storage,
}
}
// RegisterRoutes 注册路由
func (h *Handler) RegisterRoutes(router *mux.Router) {
// 健康检查
router.HandleFunc("/health", h.handleHealth).Methods("GET")
// 存储桶状态
router.HandleFunc("/api/v1/buckets", h.handleListBuckets).Methods("GET")
router.HandleFunc("/api/v1/buckets/{bucket}/stats", h.handleBucketStats).Methods("GET")
// 预签名URL生成
router.HandleFunc("/api/v1/presign/upload", h.handlePresignUpload).Methods("POST")
router.HandleFunc("/api/v1/presign/download", h.handlePresignDownload).Methods("POST")
router.HandleFunc("/api/v1/presign/delete", h.handlePresignDelete).Methods("POST")
router.HandleFunc("/api/v1/presign/multipart", h.handlePresignMultipart).Methods("POST")
// 对象操作(记录元数据)
router.HandleFunc("/api/v1/objects", h.handleListObjects).Methods("GET")
router.HandleFunc("/api/v1/objects/{key:.*}", h.handleGetObjectInfo).Methods("GET")
router.HandleFunc("/api/v1/objects/{key:.*}", h.handleDeleteObject).Methods("DELETE")
}
// 健康检查
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"status": "healthy",
"time": time.Now().Unix(),
}
h.sendJSON(w, http.StatusOK, response)
}
// 列出所有存储桶状态
func (h *Handler) handleListBuckets(w http.ResponseWriter, r *http.Request) {
buckets := h.bucketManager.GetAllBuckets()
var bucketList []map[string]interface{}
for _, b := range buckets {
bucketList = append(bucketList, map[string]interface{}{
"name": b.Config.Name,
"endpoint": b.Config.Endpoint,
"region": b.Config.Region,
"max_size": b.Config.MaxSize,
"max_size_bytes": b.Config.MaxSizeBytes,
"used_size": b.GetUsedSize(),
"available": b.IsAvailable(),
"weight": b.Config.Weight,
"enabled": b.Config.Enabled,
})
}
h.sendJSON(w, http.StatusOK, map[string]interface{}{
"buckets": bucketList,
"strategy": h.balancer.GetStrategy(),
})
}
// 获取单个存储桶统计
func (h *Handler) handleBucketStats(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
bucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendError(w, http.StatusNotFound, "bucket not found")
return
}
stats := map[string]interface{}{
"name": bucket.Config.Name,
"max_size_bytes": bucket.Config.MaxSizeBytes,
"used_size": bucket.GetUsedSize(),
"available_space": bucket.GetAvailableSpace(),
"available": bucket.IsAvailable(),
"last_checked": bucket.LastChecked,
}
h.sendJSON(w, http.StatusOK, stats)
}
// PresignUploadRequest 上传预签名请求
type PresignUploadRequest struct {
Key string `json:"key"`
Size int64 `json:"size"`
ContentType string `json:"content_type,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// 生成上传预签名URL
func (h *Handler) handlePresignUpload(w http.ResponseWriter, r *http.Request) {
var req PresignUploadRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.sendError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Key == "" {
h.sendError(w, http.StatusBadRequest, "key is required")
return
}
// 选择存储桶
bucket, err := h.balancer.SelectBucket(req.Key, req.Size)
if err != nil {
h.sendError(w, http.StatusServiceUnavailable, err.Error())
return
}
// 生成预签名URL
uploadURL, err := h.presigner.GenerateUploadURL(
context.Background(),
bucket,
req.Key,
req.ContentType,
req.Metadata,
)
if err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to generate upload URL")
return
}
// 记录对象元数据
if err := h.storage.RecordObject(req.Key, bucket.Config.Name, req.Size, req.Metadata); err != nil {
log.Printf("Failed to record object metadata: %v", err)
}
// 更新存储桶使用量(预估)
bucket.UpdateUsedSize(req.Size)
h.sendJSON(w, http.StatusOK, uploadURL)
}
// PresignDownloadRequest 下载预签名请求
type PresignDownloadRequest struct {
Key string `json:"key"`
}
// 生成下载预签名URL
func (h *Handler) handlePresignDownload(w http.ResponseWriter, r *http.Request) {
var req PresignDownloadRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.sendError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Key == "" {
h.sendError(w, http.StatusBadRequest, "key is required")
return
}
// 查找对象所在的存储桶
bucketName, err := h.storage.FindObjectBucket(req.Key)
if err != nil {
h.sendError(w, http.StatusNotFound, "object not found")
return
}
bucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendError(w, http.StatusNotFound, "bucket not found")
return
}
// 生成预签名URL
downloadURL, err := h.presigner.GenerateDownloadURL(
context.Background(),
bucket,
req.Key,
)
if err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to generate download URL")
return
}
h.sendJSON(w, http.StatusOK, downloadURL)
}
// PresignDeleteRequest 删除预签名请求
type PresignDeleteRequest struct {
Key string `json:"key"`
}
// 生成删除预签名URL
func (h *Handler) handlePresignDelete(w http.ResponseWriter, r *http.Request) {
var req PresignDeleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.sendError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Key == "" {
h.sendError(w, http.StatusBadRequest, "key is required")
return
}
// 查找对象所在的存储桶
bucketName, err := h.storage.FindObjectBucket(req.Key)
if err != nil {
h.sendError(w, http.StatusNotFound, "object not found")
return
}
bucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendError(w, http.StatusNotFound, "bucket not found")
return
}
// 生成预签名URL
deleteURL, err := h.presigner.GenerateDeleteURL(
context.Background(),
bucket,
req.Key,
)
if err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to generate delete URL")
return
}
h.sendJSON(w, http.StatusOK, deleteURL)
}
// PresignMultipartRequest 分片上传预签名请求
type PresignMultipartRequest struct {
Key string `json:"key"`
PartCount int `json:"part_count"`
Size int64 `json:"size"`
}
// 生成分片上传预签名URLs
func (h *Handler) handlePresignMultipart(w http.ResponseWriter, r *http.Request) {
var req PresignMultipartRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.sendError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Key == "" || req.PartCount <= 0 {
h.sendError(w, http.StatusBadRequest, "invalid parameters")
return
}
// 选择存储桶
bucket, err := h.balancer.SelectBucket(req.Key, req.Size)
if err != nil {
h.sendError(w, http.StatusServiceUnavailable, err.Error())
return
}
// 生成预签名URLs
multipartURLs, err := h.presigner.GenerateMultipartUploadURLs(
context.Background(),
bucket,
req.Key,
req.PartCount,
)
if err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to generate multipart URLs")
return
}
// 记录对象元数据
if err := h.storage.RecordObject(req.Key, bucket.Config.Name, req.Size, nil); err != nil {
log.Printf("Failed to record object metadata: %v", err)
}
// 更新存储桶使用量(预估)
bucket.UpdateUsedSize(req.Size)
h.sendJSON(w, http.StatusOK, multipartURLs)
}
// 列出对象
func (h *Handler) handleListObjects(w http.ResponseWriter, r *http.Request) {
prefix := r.URL.Query().Get("prefix")
bucketName := r.URL.Query().Get("bucket")
marker := r.URL.Query().Get("marker")
limitStr := r.URL.Query().Get("limit")
limit := 100
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
// 调用更新后的ListObjects方法传入所有必需的参数
objects, err := h.storage.ListObjects(bucketName, prefix, marker, limit)
if err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to list objects")
return
}
h.sendJSON(w, http.StatusOK, map[string]interface{}{
"objects": objects,
"count": len(objects),
})
}
// 获取对象信息
func (h *Handler) handleGetObjectInfo(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
key := vars["key"]
info, err := h.storage.GetObjectInfo(key)
if err != nil {
h.sendError(w, http.StatusNotFound, "object not found")
return
}
h.sendJSON(w, http.StatusOK, info)
}
// 删除对象(只删除元数据记录)
func (h *Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
key := vars["key"]
// 获取对象信息以更新存储桶使用量
info, err := h.storage.GetObjectInfo(key)
if err != nil {
h.sendError(w, http.StatusNotFound, "object not found")
return
}
// 更新存储桶使用量
if bucket, ok := h.bucketManager.GetBucket(info.BucketName); ok {
bucket.UpdateUsedSize(-info.Size)
}
// 删除元数据记录
if err := h.storage.DeleteObject(key); err != nil {
h.sendError(w, http.StatusInternalServerError, "failed to delete object")
return
}
h.sendJSON(w, http.StatusOK, map[string]string{
"message": "object deleted successfully",
})
}
// 发送JSON响应
func (h *Handler) sendJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
// 发送错误响应
func (h *Handler) sendError(w http.ResponseWriter, status int, message string) {
h.sendJSON(w, status, map[string]string{
"error": message,
})
}

780
internal/api/s3_handler.go Normal file
View File

@@ -0,0 +1,780 @@
package api
import (
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/DullJZ/s3-balance/internal/balancer"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/DullJZ/s3-balance/pkg/presigner"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/gorilla/mux"
)
// S3Handler S3兼容的API处理器
type S3Handler struct {
bucketManager *bucket.Manager
balancer *balancer.Balancer
presigner *presigner.Presigner
storage *storage.Service
accessKey string
secretKey string
}
// NewS3Handler 创建新的S3兼容API处理器
func NewS3Handler(
bucketManager *bucket.Manager,
balancer *balancer.Balancer,
presigner *presigner.Presigner,
storage *storage.Service,
accessKey string,
secretKey string,
) *S3Handler {
return &S3Handler{
bucketManager: bucketManager,
balancer: balancer,
presigner: presigner,
storage: storage,
accessKey: accessKey,
secretKey: secretKey,
}
}
// RegisterS3Routes 注册S3兼容的路由
func (h *S3Handler) RegisterS3Routes(router *mux.Router) {
// Service operations
router.HandleFunc("/", h.handleListBuckets).Methods("GET")
// Bucket operations
router.HandleFunc("/{bucket}", h.handleBucketOperations).Methods("GET", "HEAD", "PUT", "DELETE")
// Object operations
router.HandleFunc("/{bucket}/{key:.*}", h.handleObjectOperations).Methods("GET", "HEAD", "PUT", "DELETE")
// Multipart upload operations
router.HandleFunc("/{bucket}/{key:.*}", h.handleMultipartUpload).Methods("POST").Queries("uploads", "")
router.HandleFunc("/{bucket}/{key:.*}", h.handleListMultipartUploads).Methods("GET").Queries("uploads", "")
router.HandleFunc("/{bucket}/{key:.*}", h.handleListMultipartParts).Methods("GET").Queries("uploadId", "")
router.HandleFunc("/{bucket}/{key:.*}", h.handleCompleteMultipartUpload).Methods("POST").Queries("uploadId", "")
router.HandleFunc("/{bucket}/{key:.*}", h.handleAbortMultipartUpload).Methods("DELETE").Queries("uploadId", "")
// 添加认证中间件
router.Use(h.s3AuthMiddleware)
}
// S3 XML响应结构体定义
type ListBucketsResult struct {
XMLName xml.Name `xml:"ListAllMyBucketsResult"`
Xmlns string `xml:"xmlns,attr"`
Owner Owner `xml:"Owner"`
Buckets Buckets `xml:"Buckets"`
}
type Owner struct {
ID string `xml:"ID"`
DisplayName string `xml:"DisplayName"`
}
type Buckets struct {
Bucket []BucketInfo `xml:"Bucket"`
}
type BucketInfo struct {
Name string `xml:"Name"`
CreationDate time.Time `xml:"CreationDate"`
}
type ListBucketResult struct {
XMLName xml.Name `xml:"ListBucketResult"`
Xmlns string `xml:"xmlns,attr"`
Name string `xml:"Name"`
Prefix string `xml:"Prefix"`
Marker string `xml:"Marker"`
MaxKeys int `xml:"MaxKeys"`
IsTruncated bool `xml:"IsTruncated"`
Contents []ObjectInfo `xml:"Contents"`
CommonPrefixes []CommonPrefix `xml:"CommonPrefixes,omitempty"`
}
type ObjectInfo struct {
Key string `xml:"Key"`
LastModified time.Time `xml:"LastModified"`
ETag string `xml:"ETag"`
Size int64 `xml:"Size"`
StorageClass string `xml:"StorageClass"`
Owner Owner `xml:"Owner"`
}
type CommonPrefix struct {
Prefix string `xml:"Prefix"`
}
type InitiateMultipartUploadResult struct {
XMLName xml.Name `xml:"InitiateMultipartUploadResult"`
Xmlns string `xml:"xmlns,attr"`
Bucket string `xml:"Bucket"`
Key string `xml:"Key"`
UploadID string `xml:"UploadId"`
}
type ListMultipartUploadsResult struct {
XMLName xml.Name `xml:"ListMultipartUploadsResult"`
Xmlns string `xml:"xmlns,attr"`
Bucket string `xml:"Bucket"`
KeyMarker string `xml:"KeyMarker"`
UploadIdMarker string `xml:"UploadIdMarker"`
NextKeyMarker string `xml:"NextKeyMarker"`
NextUploadIdMarker string `xml:"NextUploadIdMarker"`
MaxUploads int `xml:"MaxUploads"`
IsTruncated bool `xml:"IsTruncated"`
Uploads []Upload `xml:"Upload"`
CommonPrefixes []CommonPrefix `xml:"CommonPrefixes,omitempty"`
}
type Upload struct {
Key string `xml:"Key"`
UploadID string `xml:"UploadId"`
Initiator Owner `xml:"Initiator"`
Owner Owner `xml:"Owner"`
StorageClass string `xml:"StorageClass"`
Initiated time.Time `xml:"Initiated"`
}
type ListPartsResult struct {
XMLName xml.Name `xml:"ListPartsResult"`
Xmlns string `xml:"xmlns,attr"`
Bucket string `xml:"Bucket"`
Key string `xml:"Key"`
UploadID string `xml:"UploadId"`
PartNumberMarker int `xml:"PartNumberMarker"`
NextPartNumberMarker int `xml:"NextPartNumberMarker"`
MaxParts int `xml:"MaxParts"`
IsTruncated bool `xml:"IsTruncated"`
Parts []Part `xml:"Part"`
}
type Part struct {
PartNumber int `xml:"PartNumber"`
LastModified time.Time `xml:"LastModified"`
ETag string `xml:"ETag"`
Size int64 `xml:"Size"`
}
type CompleteMultipartUpload struct {
XMLName xml.Name `xml:"CompleteMultipartUpload"`
Parts []Part `xml:"Part"`
}
type CompleteMultipartUploadResult struct {
XMLName xml.Name `xml:"CompleteMultipartUploadResult"`
Xmlns string `xml:"xmlns,attr"`
Location string `xml:"Location"`
Bucket string `xml:"Bucket"`
Key string `xml:"Key"`
ETag string `xml:"ETag"`
}
type ErrorResponse struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
Resource string `xml:"Resource"`
RequestID string `xml:"RequestId"`
}
// s3AuthMiddleware S3认证中间件简化版
func (h *S3Handler) s3AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 简化的认证实现实际应该验证AWS Signature
// 这里只做基本的header检查
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// 允许匿名访问(用于测试)
// 在生产环境中应该要求认证
}
next.ServeHTTP(w, r)
})
}
// handleListBuckets 处理列出所有存储桶请求
func (h *S3Handler) handleListBuckets(w http.ResponseWriter, r *http.Request) {
buckets := h.bucketManager.GetAllBuckets()
result := ListBucketsResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Owner: Owner{
ID: "s3-balance",
DisplayName: "S3 Balance Service",
},
Buckets: Buckets{
Bucket: make([]BucketInfo, 0, len(buckets)),
},
}
for _, b := range buckets {
if b.IsAvailable() {
result.Buckets.Bucket = append(result.Buckets.Bucket, BucketInfo{
Name: b.Config.Name,
CreationDate: time.Now().Add(-24 * time.Hour), // 模拟创建时间
})
}
}
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleBucketOperations 处理存储桶相关操作
func (h *S3Handler) handleBucketOperations(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
switch r.Method {
case "GET":
h.handleListObjects(w, r, bucketName)
case "HEAD":
h.handleHeadBucket(w, r, bucketName)
case "PUT":
h.handleCreateBucket(w, r, bucketName)
case "DELETE":
h.handleDeleteBucket(w, r, bucketName)
}
}
// handleListObjects 列出存储桶中的对象
func (h *S3Handler) handleListObjects(w http.ResponseWriter, r *http.Request, bucketName string) {
// 检查bucket是否存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 解析查询参数
prefix := r.URL.Query().Get("prefix")
marker := r.URL.Query().Get("marker")
maxKeysStr := r.URL.Query().Get("max-keys")
delimiter := r.URL.Query().Get("delimiter")
maxKeys := 1000
if maxKeysStr != "" {
if mk, err := strconv.Atoi(maxKeysStr); err == nil {
maxKeys = mk
}
}
// 从存储中获取对象列表
objects, err := h.storage.ListObjects(bucketName, prefix, marker, maxKeys)
if err != nil {
h.sendS3Error(w, "InternalError", "Internal server error", bucketName)
return
}
result := ListBucketResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Name: bucketName,
Prefix: prefix,
Marker: marker,
MaxKeys: maxKeys,
IsTruncated: false, // 简化实现
Contents: make([]ObjectInfo, 0),
}
// 处理分隔符逻辑
if delimiter != "" {
// 简化的分隔符处理
result.CommonPrefixes = make([]CommonPrefix, 0)
}
for _, obj := range objects {
result.Contents = append(result.Contents, ObjectInfo{
Key: obj.Key,
LastModified: obj.UpdatedAt,
ETag: fmt.Sprintf("\"%x\"", obj.ID), // 简化的ETag
Size: obj.Size,
StorageClass: "STANDARD",
Owner: Owner{
ID: "s3-balance",
DisplayName: "S3 Balance Service",
},
})
}
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleHeadBucket 检查存储桶是否存在
func (h *S3Handler) handleHeadBucket(w http.ResponseWriter, r *http.Request, bucketName string) {
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
w.WriteHeader(http.StatusNotFound)
return
}
w.WriteHeader(http.StatusOK)
}
// handleCreateBucket 创建存储桶(虚拟实现)
func (h *S3Handler) handleCreateBucket(w http.ResponseWriter, r *http.Request, bucketName string) {
// 在负载均衡场景下不真正创建bucket只返回成功
// 实际的bucket应该在配置中预先定义
w.Header().Set("Location", "/" + bucketName)
w.WriteHeader(http.StatusOK)
}
// handleDeleteBucket 删除存储桶(虚拟实现)
func (h *S3Handler) handleDeleteBucket(w http.ResponseWriter, r *http.Request, bucketName string) {
// 在负载均衡场景下不真正删除bucket
w.WriteHeader(http.StatusNoContent)
}
// handleObjectOperations 处理对象相关操作
func (h *S3Handler) handleObjectOperations(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
switch r.Method {
case "GET":
h.handleGetObject(w, r, bucketName, key)
case "HEAD":
h.handleHeadObject(w, r, bucketName, key)
case "PUT":
h.handlePutObject(w, r, bucketName, key)
case "DELETE":
h.handleDeleteObject(w, r, bucketName, key)
}
}
// handleGetObject 获取对象默认使用预签名URL重定向
func (h *S3Handler) handleGetObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) {
// 查找对象所在的实际存储桶
actualBucketName, err := h.storage.FindObjectBucket(key)
if err != nil {
h.sendS3Error(w, "NoSuchKey", "The specified key does not exist", key)
return
}
bucket, ok := h.bucketManager.GetBucket(actualBucketName)
if !ok {
h.sendS3Error(w, "InternalError", "Internal server error", bucketName)
return
}
// 生成预签名下载URL
downloadInfo, err := h.presigner.GenerateDownloadURL(
context.Background(),
bucket,
key,
)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to generate download URL", key)
return
}
// 默认使用预签名重定向模式,只有明确指定时才使用代理模式
if r.URL.Query().Get("proxy") == "true" {
// 代理模式:服务器下载内容并返回给客户端
resp, err := http.Get(downloadInfo.URL)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to fetch object", key)
return
}
defer resp.Body.Close()
// 复制响应头
for k, v := range resp.Header {
w.Header()[k] = v
}
// 复制响应体
io.Copy(w, resp.Body)
} else {
// 重定向模式返回302重定向到预签名URL默认
http.Redirect(w, r, downloadInfo.URL, http.StatusFound)
}
}
// handleHeadObject 获取对象元数据
func (h *S3Handler) handleHeadObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) {
// 从存储中获取对象信息
obj, err := h.storage.GetObjectInfo(key)
if err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Length", strconv.FormatInt(obj.Size, 10))
w.Header().Set("Last-Modified", obj.UpdatedAt.Format(http.TimeFormat))
w.Header().Set("ETag", fmt.Sprintf("\"%x\"", obj.ID))
w.Header().Set("Content-Type", "application/octet-stream")
w.WriteHeader(http.StatusOK)
}
// handlePutObject 上传对象默认使用预签名URL重定向
func (h *S3Handler) handlePutObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) {
// 获取内容长度
contentLength := r.ContentLength
if contentLength < 0 {
h.sendS3Error(w, "MissingContentLength", "Content-Length header is required", key)
return
}
// 选择目标存储桶
targetBucket, err := h.balancer.SelectBucket(key, contentLength)
if err != nil {
h.sendS3Error(w, "InsufficientStorage", "No bucket has enough space", key)
return
}
// 生成预签名上传URL
uploadInfo, err := h.presigner.GenerateUploadURL(
context.Background(),
targetBucket,
key,
r.Header.Get("Content-Type"),
nil, // metadata
)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to generate upload URL", key)
return
}
// 默认使用预签名重定向模式,只有明确指定时才使用代理模式
if r.URL.Query().Get("proxy") == "true" {
// 代理模式读取请求体并上传到预签名URL
// 创建新的请求
req, err := http.NewRequest(uploadInfo.Method, uploadInfo.URL, r.Body)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to create upload request", key)
return
}
// 设置必要的头
req.ContentLength = contentLength
if ct := r.Header.Get("Content-Type"); ct != "" {
req.Header.Set("Content-Type", ct)
}
// 添加预签名URL所需的额外头
for k, v := range uploadInfo.Headers {
req.Header.Set(k, v)
}
// 执行上传
client := &http.Client{Timeout: 30 * time.Minute}
resp, err := client.Do(req)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to upload object", key)
return
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
// 记录对象元数据
h.storage.RecordObject(key, targetBucket.Config.Name, contentLength, nil)
targetBucket.UpdateUsedSize(contentLength)
// 返回成功响应
w.Header().Set("ETag", fmt.Sprintf("\"%x\"", time.Now().UnixNano()))
w.WriteHeader(http.StatusOK)
} else {
h.sendS3Error(w, "InternalError", "Upload failed", key)
}
} else {
// 重定向模式返回307临时重定向让客户端直接上传默认
w.Header().Set("Location", uploadInfo.URL)
w.WriteHeader(http.StatusTemporaryRedirect)
}
}
// handleDeleteObject 删除对象
func (h *S3Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request, bucketName string, key string) {
// 查找对象所在的实际存储桶
actualBucketName, err := h.storage.FindObjectBucket(key)
if err != nil {
// 对象不存在S3规范要求返回204
w.WriteHeader(http.StatusNoContent)
return
}
bucket, ok := h.bucketManager.GetBucket(actualBucketName)
if !ok {
h.sendS3Error(w, "InternalError", "Internal server error", bucketName)
return
}
// 生成预签名删除URL
deleteInfo, err := h.presigner.GenerateDeleteURL(
context.Background(),
bucket,
key,
)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to generate delete URL", key)
return
}
// 执行删除
req, _ := http.NewRequest("DELETE", deleteInfo.URL, nil)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to delete object", key)
return
}
defer resp.Body.Close()
// 从数据库中删除对象记录
h.storage.DeleteObject(key)
// S3规范要求删除操作总是返回204
w.WriteHeader(http.StatusNoContent)
}
// handleMultipartUpload 初始化分片上传
func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
key := vars["key"]
// 选择目标存储桶
targetBucket, err := h.balancer.SelectBucket(key, 0) // 分片上传时不检查空间
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to select bucket for upload", key)
return
}
// 初始化分片上传
ctx := context.Background()
createResp, err := targetBucket.Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),
})
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to initiate multipart upload", key)
return
}
result := InitiateMultipartUploadResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: targetBucket.Config.Name,
Key: key,
UploadID: *createResp.UploadId,
}
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleListMultipartUploads 列出分片上传
func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
// 检查bucket是否存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 简化实现:返回空列表
result := ListMultipartUploadsResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: bucketName,
KeyMarker: key,
MaxUploads: 1000,
IsTruncated: false,
Uploads: make([]Upload, 0),
}
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleListMultipartParts 列出分片上传的分片
func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
uploadID := r.URL.Query().Get("uploadId")
// 检查bucket是否存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 简化实现:返回空列表
result := ListPartsResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: bucketName,
Key: key,
UploadID: uploadID,
PartNumberMarker: 0,
MaxParts: 1000,
IsTruncated: false,
Parts: make([]Part, 0),
}
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleCompleteMultipartUpload 完成分片上传
func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
uploadID := r.URL.Query().Get("uploadId")
// 查找对象所在的实际存储桶简化实现使用配置的bucket
bucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 解析请求体以获取分片列表
var completeReq CompleteMultipartUpload
body, _ := io.ReadAll(r.Body)
xml.Unmarshal(body, &completeReq)
// 完成分片上传
ctx := context.Background()
var parts []types.CompletedPart
for _, part := range completeReq.Parts {
parts = append(parts, types.CompletedPart{
ETag: aws.String(part.ETag),
PartNumber: aws.Int32(int32(part.PartNumber)),
})
}
completeResp, err := bucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
MultipartUpload: &types.CompletedMultipartUpload{
Parts: parts,
},
})
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to complete multipart upload", key)
return
}
result := CompleteMultipartUploadResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Location: "/" + bucket.Config.Name + "/" + key,
Bucket: bucket.Config.Name,
Key: key,
ETag: *completeResp.ETag,
}
// 记录对象元数据(简化:假设总大小)
h.storage.RecordObject(key, bucket.Config.Name, 0, nil)
h.sendXMLResponse(w, http.StatusOK, result)
}
// handleAbortMultipartUpload 中止分片上传
func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
uploadID := r.URL.Query().Get("uploadId")
// 查找对象所在的实际存储桶简化实现使用配置的bucket
bucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 中止分片上传
ctx := context.Background()
_, err := bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
})
if err != nil {
h.sendS3Error(w, "InternalError", "Failed to abort multipart upload", key)
return
}
w.WriteHeader(http.StatusNoContent)
}
// sendXMLResponse 发送XML响应
func (h *S3Handler) sendXMLResponse(w http.ResponseWriter, statusCode int, data interface{}) {
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(statusCode)
encoder := xml.NewEncoder(w)
encoder.Indent("", " ")
// 写入XML声明
w.Write([]byte(xml.Header))
if err := encoder.Encode(data); err != nil {
// 如果编码失败,记录错误
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
// sendS3Error 发送S3错误响应
func (h *S3Handler) sendS3Error(w http.ResponseWriter, code string, message string, resource string) {
errorResp := ErrorResponse{
Code: code,
Message: message,
Resource: resource,
RequestID: fmt.Sprintf("%d", time.Now().UnixNano()),
}
statusCode := http.StatusBadRequest
switch code {
case "NoSuchBucket", "NoSuchKey":
statusCode = http.StatusNotFound
case "BucketAlreadyExists":
statusCode = http.StatusConflict
case "InvalidAccessKeyId", "SignatureDoesNotMatch":
statusCode = http.StatusForbidden
case "InternalError":
statusCode = http.StatusInternalServerError
case "InsufficientStorage":
statusCode = http.StatusInsufficientStorage
}
h.sendXMLResponse(w, statusCode, errorResp)
}
// 辅助函数解析S3路径
func parseS3Path(requestPath string) (bucket string, key string) {
requestPath = strings.TrimPrefix(requestPath, "/")
parts := strings.SplitN(requestPath, "/", 2)
if len(parts) > 0 {
bucket = parts[0]
}
if len(parts) > 1 {
key = parts[1]
}
return bucket, key
}
// 辅助函数URL编码/解码
func urlEncodePath(p string) string {
return strings.ReplaceAll(url.QueryEscape(p), "+", "%20")
}
func urlDecodePath(p string) string {
decoded, _ := url.QueryUnescape(p)
return decoded
}

View File

@@ -0,0 +1,264 @@
package balancer
import (
"crypto/md5"
"encoding/binary"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/config"
)
// Strategy 负载均衡策略接口
type Strategy interface {
SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error)
Name() string
}
// Balancer 负载均衡器
type Balancer struct {
manager *bucket.Manager
strategy Strategy
config *config.BalancerConfig
}
// NewBalancer 创建新的负载均衡器
func NewBalancer(manager *bucket.Manager, cfg *config.BalancerConfig) (*Balancer, error) {
var strategy Strategy
switch cfg.Strategy {
case "round-robin":
strategy = NewRoundRobinStrategy()
case "least-space":
strategy = NewLeastSpaceStrategy()
case "weighted":
strategy = NewWeightedStrategy()
case "consistent-hash":
strategy = NewConsistentHashStrategy()
default:
return nil, fmt.Errorf("unknown balancer strategy: %s", cfg.Strategy)
}
return &Balancer{
manager: manager,
strategy: strategy,
config: cfg,
}, nil
}
// SelectBucket 选择一个存储桶
func (b *Balancer) SelectBucket(key string, size int64) (*bucket.BucketInfo, error) {
// 获取所有可用的存储桶
buckets := b.manager.GetAvailableBuckets()
if len(buckets) == 0 {
return nil, fmt.Errorf("no available buckets")
}
// 过滤出有足够空间的存储桶
var availableBuckets []*bucket.BucketInfo
for _, bucket := range buckets {
if bucket.GetAvailableSpace() >= size {
availableBuckets = append(availableBuckets, bucket)
}
}
if len(availableBuckets) == 0 {
return nil, fmt.Errorf("no bucket has enough space for %d bytes", size)
}
// 使用策略选择存储桶
selected, err := b.strategy.SelectBucket(availableBuckets, key, size)
if err != nil {
return nil, err
}
return selected, nil
}
// GetStrategy 获取当前策略名称
func (b *Balancer) GetStrategy() string {
return b.strategy.Name()
}
// RoundRobinStrategy 轮询策略
type RoundRobinStrategy struct {
counter uint64
}
// NewRoundRobinStrategy 创建轮询策略
func NewRoundRobinStrategy() *RoundRobinStrategy {
return &RoundRobinStrategy{}
}
// SelectBucket 选择存储桶(轮询)
func (s *RoundRobinStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) {
if len(buckets) == 0 {
return nil, fmt.Errorf("no buckets available")
}
index := atomic.AddUint64(&s.counter, 1) % uint64(len(buckets))
return buckets[index], nil
}
// Name 返回策略名称
func (s *RoundRobinStrategy) Name() string {
return "round-robin"
}
// LeastSpaceStrategy 最少使用空间策略
type LeastSpaceStrategy struct{}
// NewLeastSpaceStrategy 创建最少使用空间策略
func NewLeastSpaceStrategy() *LeastSpaceStrategy {
return &LeastSpaceStrategy{}
}
// SelectBucket 选择存储桶(选择使用空间最少的)
func (s *LeastSpaceStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) {
if len(buckets) == 0 {
return nil, fmt.Errorf("no buckets available")
}
// 按可用空间排序(从大到小)
sort.Slice(buckets, func(i, j int) bool {
return buckets[i].GetAvailableSpace() > buckets[j].GetAvailableSpace()
})
return buckets[0], nil
}
// Name 返回策略名称
func (s *LeastSpaceStrategy) Name() string {
return "least-space"
}
// WeightedStrategy 加权策略
type WeightedStrategy struct {
mu sync.RWMutex
}
// NewWeightedStrategy 创建加权策略
func NewWeightedStrategy() *WeightedStrategy {
return &WeightedStrategy{}
}
// SelectBucket 选择存储桶(基于权重)
func (s *WeightedStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) {
if len(buckets) == 0 {
return nil, fmt.Errorf("no buckets available")
}
// 计算总权重
totalWeight := 0
for _, b := range buckets {
totalWeight += b.Config.Weight
}
if totalWeight == 0 {
// 如果所有权重都是0则随机选择
return buckets[rand.Intn(len(buckets))], nil
}
// 根据权重随机选择
randomWeight := rand.Intn(totalWeight)
currentWeight := 0
for _, b := range buckets {
currentWeight += b.Config.Weight
if randomWeight < currentWeight {
return b, nil
}
}
// 不应该到达这里,但为了安全返回最后一个
return buckets[len(buckets)-1], nil
}
// Name 返回策略名称
func (s *WeightedStrategy) Name() string {
return "weighted"
}
// ConsistentHashStrategy 一致性哈希策略
type ConsistentHashStrategy struct {
replicas int
ring map[uint32]*bucket.BucketInfo
nodes []uint32
mu sync.RWMutex
}
// NewConsistentHashStrategy 创建一致性哈希策略
func NewConsistentHashStrategy() *ConsistentHashStrategy {
return &ConsistentHashStrategy{
replicas: 100, // 每个节点的虚拟节点数
ring: make(map[uint32]*bucket.BucketInfo),
}
}
// SelectBucket 选择存储桶(基于一致性哈希)
func (s *ConsistentHashStrategy) SelectBucket(buckets []*bucket.BucketInfo, key string, size int64) (*bucket.BucketInfo, error) {
if len(buckets) == 0 {
return nil, fmt.Errorf("no buckets available")
}
// 更新哈希环
s.updateRing(buckets)
// 计算key的哈希值
hash := s.hash(key)
// 在环上找到第一个大于等于hash的节点
s.mu.RLock()
defer s.mu.RUnlock()
idx := sort.Search(len(s.nodes), func(i int) bool {
return s.nodes[i] >= hash
})
// 如果没找到,返回第一个节点(环形结构)
if idx == len(s.nodes) {
idx = 0
}
return s.ring[s.nodes[idx]], nil
}
// updateRing 更新哈希环
func (s *ConsistentHashStrategy) updateRing(buckets []*bucket.BucketInfo) {
s.mu.Lock()
defer s.mu.Unlock()
// 清空现有环
s.ring = make(map[uint32]*bucket.BucketInfo)
s.nodes = nil
// 为每个存储桶添加虚拟节点
for _, b := range buckets {
for i := 0; i < s.replicas; i++ {
virtualKey := fmt.Sprintf("%s-%d", b.Config.Name, i)
hash := s.hash(virtualKey)
s.ring[hash] = b
s.nodes = append(s.nodes, hash)
}
}
// 排序节点
sort.Slice(s.nodes, func(i, j int) bool {
return s.nodes[i] < s.nodes[j]
})
}
// hash 计算哈希值
func (s *ConsistentHashStrategy) hash(key string) uint32 {
h := md5.Sum([]byte(key))
return binary.BigEndian.Uint32(h[:4])
}
// Name 返回策略名称
func (s *ConsistentHashStrategy) Name() string {
return "consistent-hash"
}

310
internal/bucket/manager.go Normal file
View File

@@ -0,0 +1,310 @@
package bucket
import (
"context"
"fmt"
"sync"
"time"
"github.com/DullJZ/s3-balance/internal/config"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
// BucketInfo 存储桶信息
type BucketInfo struct {
Config config.BucketConfig
Client *s3.Client
UsedSize int64 // 已使用容量(字节)
Available bool // 是否可用
LastChecked time.Time // 最后检查时间
mu sync.RWMutex
}
// Manager 存储桶管理器
type Manager struct {
buckets map[string]*BucketInfo
mu sync.RWMutex
config *config.Config
stopChan chan struct{}
}
// NewManager 创建新的存储桶管理器
func NewManager(cfg *config.Config) (*Manager, error) {
m := &Manager{
buckets: make(map[string]*BucketInfo),
config: cfg,
stopChan: make(chan struct{}),
}
// 初始化所有存储桶客户端
for _, bucketCfg := range cfg.Buckets {
if !bucketCfg.Enabled {
continue
}
client, err := createS3Client(bucketCfg)
if err != nil {
return nil, fmt.Errorf("failed to create S3 client for bucket %s: %w", bucketCfg.Name, err)
}
info := &BucketInfo{
Config: bucketCfg,
Client: client,
Available: true,
LastChecked: time.Now(),
}
m.buckets[bucketCfg.Name] = info
}
return m, nil
}
// createS3Client 创建S3客户端
func createS3Client(bucketCfg config.BucketConfig) (*s3.Client, error) {
// 创建自定义端点解析器
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
if bucketCfg.Endpoint != "" {
return aws.Endpoint{
URL: bucketCfg.Endpoint,
SigningRegion: bucketCfg.Region,
HostnameImmutable: true,
}, nil
}
// 返回错误以使用默认解析器
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})
// 配置AWS SDK
cfg, err := awsconfig.LoadDefaultConfig(context.TODO(),
awsconfig.WithRegion(bucketCfg.Region),
awsconfig.WithEndpointResolverWithOptions(customResolver),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(
bucketCfg.AccessKeyID,
bucketCfg.SecretAccessKey,
"",
),
),
)
if err != nil {
return nil, err
}
// 创建S3客户端
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.UsePathStyle = bucketCfg.PathStyle
})
return client, nil
}
// Start 启动管理器(健康检查和统计更新)
func (m *Manager) Start(ctx context.Context) {
// 启动健康检查
go m.healthCheckLoop(ctx)
// 启动统计更新
go m.statsUpdateLoop(ctx)
}
// Stop 停止管理器
func (m *Manager) Stop() {
close(m.stopChan)
}
// healthCheckLoop 健康检查循环
func (m *Manager) healthCheckLoop(ctx context.Context) {
ticker := time.NewTicker(m.config.Balancer.HealthCheckPeriod)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-m.stopChan:
return
case <-ticker.C:
m.checkAllBuckets(ctx)
}
}
}
// statsUpdateLoop 统计更新循环
func (m *Manager) statsUpdateLoop(ctx context.Context) {
ticker := time.NewTicker(m.config.Balancer.UpdateStatsPeriod)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-m.stopChan:
return
case <-ticker.C:
m.updateAllStats(ctx)
}
}
}
// checkAllBuckets 检查所有存储桶的健康状态
func (m *Manager) checkAllBuckets(ctx context.Context) {
m.mu.RLock()
buckets := make([]*BucketInfo, 0, len(m.buckets))
for _, b := range m.buckets {
buckets = append(buckets, b)
}
m.mu.RUnlock()
var wg sync.WaitGroup
for _, bucket := range buckets {
wg.Add(1)
go func(b *BucketInfo) {
defer wg.Done()
m.checkBucket(ctx, b)
}(bucket)
}
wg.Wait()
}
// checkBucket 检查单个存储桶
func (m *Manager) checkBucket(ctx context.Context, bucket *BucketInfo) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// 尝试列出存储桶(用于健康检查)
_, err := bucket.Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(bucket.Config.Name),
MaxKeys: aws.Int32(1),
})
bucket.mu.Lock()
bucket.Available = err == nil
bucket.LastChecked = time.Now()
bucket.mu.Unlock()
}
// updateAllStats 更新所有存储桶的统计信息
func (m *Manager) updateAllStats(ctx context.Context) {
m.mu.RLock()
buckets := make([]*BucketInfo, 0, len(m.buckets))
for _, b := range m.buckets {
buckets = append(buckets, b)
}
m.mu.RUnlock()
var wg sync.WaitGroup
for _, bucket := range buckets {
wg.Add(1)
go func(b *BucketInfo) {
defer wg.Done()
m.updateBucketStats(ctx, b)
}(bucket)
}
wg.Wait()
}
// updateBucketStats 更新单个存储桶的统计信息
func (m *Manager) updateBucketStats(ctx context.Context, bucket *BucketInfo) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
var totalSize int64
var continuationToken *string
for {
output, err := bucket.Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(bucket.Config.Name),
ContinuationToken: continuationToken,
})
if err != nil {
break
}
for _, obj := range output.Contents {
if obj.Size != nil {
totalSize += *obj.Size
}
}
if output.IsTruncated == nil || !*output.IsTruncated {
break
}
continuationToken = output.NextContinuationToken
}
bucket.mu.Lock()
bucket.UsedSize = totalSize
bucket.mu.Unlock()
}
// GetBucket 获取指定名称的存储桶
func (m *Manager) GetBucket(name string) (*BucketInfo, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
bucket, ok := m.buckets[name]
return bucket, ok
}
// GetAllBuckets 获取所有存储桶
func (m *Manager) GetAllBuckets() []*BucketInfo {
m.mu.RLock()
defer m.mu.RUnlock()
buckets := make([]*BucketInfo, 0, len(m.buckets))
for _, b := range m.buckets {
buckets = append(buckets, b)
}
return buckets
}
// GetAvailableBuckets 获取所有可用的存储桶
func (m *Manager) GetAvailableBuckets() []*BucketInfo {
m.mu.RLock()
defer m.mu.RUnlock()
var available []*BucketInfo
for _, b := range m.buckets {
b.mu.RLock()
if b.Available && (b.Config.MaxSizeBytes == 0 || b.UsedSize < b.Config.MaxSizeBytes) {
available = append(available, b)
}
b.mu.RUnlock()
}
return available
}
// GetAvailableSpace 获取存储桶的可用空间
func (b *BucketInfo) GetAvailableSpace() int64 {
b.mu.RLock()
defer b.mu.RUnlock()
if b.Config.MaxSizeBytes == 0 {
return 1 << 62 // 返回一个很大的数表示无限制
}
return b.Config.MaxSizeBytes - b.UsedSize
}
// IsAvailable 检查存储桶是否可用
func (b *BucketInfo) IsAvailable() bool {
b.mu.RLock()
defer b.mu.RUnlock()
return b.Available
}
// GetUsedSize 获取已使用容量
func (b *BucketInfo) GetUsedSize() int64 {
b.mu.RLock()
defer b.mu.RUnlock()
return b.UsedSize
}
// UpdateUsedSize 更新已使用容量
func (b *BucketInfo) UpdateUsedSize(delta int64) {
b.mu.Lock()
defer b.mu.Unlock()
b.UsedSize += delta
}

212
internal/config/config.go Normal file
View File

@@ -0,0 +1,212 @@
package config
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
// Config 全局配置结构
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Buckets []BucketConfig `yaml:"buckets"`
Balancer BalancerConfig `yaml:"balancer"`
Metrics MetricsConfig `yaml:"metrics"`
S3API S3APIConfig `yaml:"s3api"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
IdleTimeout time.Duration `yaml:"idle_timeout"`
}
// BucketConfig S3存储桶配置
type BucketConfig struct {
Name string `yaml:"name"` // 桶名称
Endpoint string `yaml:"endpoint"` // S3端点
Region string `yaml:"region"` // 区域
AccessKeyID string `yaml:"access_key_id"` // 访问密钥ID
SecretAccessKey string `yaml:"secret_access_key"` // 访问密钥
MaxSize string `yaml:"max_size"` // 最大容量 (例如: "10GB")
MaxSizeBytes int64 `yaml:"-"` // 内部使用,字节为单位
Weight int `yaml:"weight"` // 权重 (用于负载均衡)
Enabled bool `yaml:"enabled"` // 是否启用
UseSSL bool `yaml:"use_ssl"` // 是否使用SSL
PathStyle bool `yaml:"path_style"` // 是否使用路径风格访问
}
// BalancerConfig 负载均衡配置
type BalancerConfig struct {
Strategy string `yaml:"strategy"` // 负载均衡策略: "round-robin", "least-space", "weighted", "consistent-hash"
HealthCheckPeriod time.Duration `yaml:"health_check_period"` // 健康检查周期
UpdateStatsPeriod time.Duration `yaml:"update_stats_period"` // 统计更新周期
RetryAttempts int `yaml:"retry_attempts"` // 重试次数
RetryDelay time.Duration `yaml:"retry_delay"` // 重试延迟
}
// MetricsConfig 监控指标配置
type MetricsConfig struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
Port int `yaml:"port"`
}
// S3APIConfig S3兼容API配置
type S3APIConfig struct {
AccessKey string `yaml:"access_key"` // S3访问密钥ID
SecretKey string `yaml:"secret_key"` // S3秘密访问密钥
VirtualHost bool `yaml:"virtual_host"` // 是否使用虚拟主机模式
ProxyMode bool `yaml:"proxy_mode"` // 是否使用代理模式(而非重定向)
AuthRequired bool `yaml:"auth_required"` // 是否需要认证
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Type string `yaml:"type"` // 数据库类型: sqlite, mysql, postgres
DSN string `yaml:"dsn"` // 数据源名称
MaxOpenConns int `yaml:"max_open_conns"` // 最大打开连接数
MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数
ConnMaxLifetime int `yaml:"conn_max_lifetime"` // 连接最大生命周期(秒)
LogLevel string `yaml:"log_level"` // 日志级别: silent, error, warn, info
AutoMigrate bool `yaml:"auto_migrate"` // 是否自动迁移
}
// Load 从文件加载配置
func Load(configPath string) (*Config, error) {
file, err := os.Open(configPath)
if err != nil {
return nil, fmt.Errorf("failed to open config file: %w", err)
}
defer file.Close()
var config Config
decoder := yaml.NewDecoder(file)
if err := decoder.Decode(&config); err != nil {
return nil, fmt.Errorf("failed to decode config: %w", err)
}
// 解析容量大小
for i := range config.Buckets {
if err := config.Buckets[i].ParseMaxSize(); err != nil {
return nil, fmt.Errorf("failed to parse max size for bucket %s: %w",
config.Buckets[i].Name, err)
}
}
// 设置默认值
config.SetDefaults()
return &config, nil
}
// SetDefaults 设置默认配置值
func (c *Config) SetDefaults() {
if c.Server.Host == "" {
c.Server.Host = "0.0.0.0"
}
if c.Server.Port == 0 {
c.Server.Port = 8080
}
if c.Server.ReadTimeout == 0 {
c.Server.ReadTimeout = 30 * time.Second
}
if c.Server.WriteTimeout == 0 {
c.Server.WriteTimeout = 30 * time.Second
}
if c.Server.IdleTimeout == 0 {
c.Server.IdleTimeout = 60 * time.Second
}
if c.Balancer.Strategy == "" {
c.Balancer.Strategy = "least-space"
}
if c.Balancer.HealthCheckPeriod == 0 {
c.Balancer.HealthCheckPeriod = 30 * time.Second
}
if c.Balancer.UpdateStatsPeriod == 0 {
c.Balancer.UpdateStatsPeriod = 60 * time.Second
}
if c.Balancer.RetryAttempts == 0 {
c.Balancer.RetryAttempts = 3
}
if c.Balancer.RetryDelay == 0 {
c.Balancer.RetryDelay = time.Second
}
if c.Metrics.Path == "" {
c.Metrics.Path = "/metrics"
}
if c.Metrics.Port == 0 {
c.Metrics.Port = 9090
}
// 数据库默认值
if c.Database.Type == "" {
c.Database.Type = "sqlite"
}
if c.Database.Type == "sqlite" && c.Database.DSN == "" {
c.Database.DSN = "data/s3-balance.db"
}
if c.Database.MaxOpenConns == 0 {
c.Database.MaxOpenConns = 25
}
if c.Database.MaxIdleConns == 0 {
c.Database.MaxIdleConns = 5
}
if c.Database.ConnMaxLifetime == 0 {
c.Database.ConnMaxLifetime = 300
}
if c.Database.LogLevel == "" {
c.Database.LogLevel = "warn"
}
// S3 API默认值
if c.S3API.AccessKey == "" {
c.S3API.AccessKey = "AKIAIOSFODNN7EXAMPLE"
}
if c.S3API.SecretKey == "" {
c.S3API.SecretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
}
// 默认使用预签名模式(非代理模式)
c.S3API.ProxyMode = false
c.S3API.AuthRequired = false
}
// ParseMaxSize 解析最大容量字符串为字节
func (bc *BucketConfig) ParseMaxSize() error {
if bc.MaxSize == "" {
bc.MaxSizeBytes = 0 // 无限制
return nil
}
var size int64
var unit string
_, err := fmt.Sscanf(bc.MaxSize, "%d%s", &size, &unit)
if err != nil {
return fmt.Errorf("invalid size format: %s", bc.MaxSize)
}
switch unit {
case "B", "b":
bc.MaxSizeBytes = size
case "KB", "kb", "K", "k":
bc.MaxSizeBytes = size * 1024
case "MB", "mb", "M", "m":
bc.MaxSizeBytes = size * 1024 * 1024
case "GB", "gb", "G", "g":
bc.MaxSizeBytes = size * 1024 * 1024 * 1024
case "TB", "tb", "T", "t":
bc.MaxSizeBytes = size * 1024 * 1024 * 1024 * 1024
default:
return fmt.Errorf("unsupported unit: %s", unit)
}
return nil
}

View File

@@ -0,0 +1,230 @@
package database
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/DullJZ/s3-balance/internal/config"
"github.com/DullJZ/s3-balance/internal/storage"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DB 全局数据库连接
var DB *gorm.DB
// Initialize 初始化数据库连接
func Initialize(cfg *config.DatabaseConfig) error {
var err error
// 设置日志级别
logLevel := getLogLevel(cfg.LogLevel)
// GORM配置
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logLevel),
NowFunc: func() time.Time {
return time.Now().Local()
},
QueryFields: true,
}
// 根据数据库类型创建连接
switch cfg.Type {
case "sqlite":
DB, err = connectSQLite(cfg.DSN, gormConfig)
case "mysql":
DB, err = connectMySQL(cfg.DSN, gormConfig)
case "postgres", "postgresql":
DB, err = connectPostgreSQL(cfg.DSN, gormConfig)
default:
return fmt.Errorf("unsupported database type: %s", cfg.Type)
}
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
// 获取底层SQL数据库连接
sqlDB, err := DB.DB()
if err != nil {
return fmt.Errorf("failed to get sql.DB: %w", err)
}
// 设置连接池参数
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return fmt.Errorf("failed to ping database: %w", err)
}
// 自动迁移
if cfg.AutoMigrate {
if err := AutoMigrate(); err != nil {
return fmt.Errorf("failed to auto migrate: %w", err)
}
}
log.Printf("Successfully connected to %s database", cfg.Type)
return nil
}
// connectSQLite 连接SQLite数据库
func connectSQLite(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
// 创建数据目录(如果不存在)
dir := filepath.Dir(dsn)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// 添加SQLite特定参数
if dsn != ":memory:" {
dsn = fmt.Sprintf("%s?_journal_mode=WAL&_timeout=5000&_synchronous=NORMAL&_cache_size=10000", dsn)
}
return gorm.Open(sqlite.Open(dsn), gormConfig)
}
// connectMySQL 连接MySQL数据库
func connectMySQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
// MySQL DSN示例: user:password@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local
// 如果DSN中没有指定字符集添加默认字符集
if dsn != "" {
dsn = ensureMySQLParams(dsn)
}
return gorm.Open(mysql.Open(dsn), gormConfig)
}
// connectPostgreSQL 连接PostgreSQL数据库
func connectPostgreSQL(dsn string, gormConfig *gorm.Config) (*gorm.DB, error) {
// PostgreSQL DSN示例: host=localhost user=user password=password dbname=mydb port=5432 sslmode=disable TimeZone=Asia/Shanghai
return gorm.Open(postgres.Open(dsn), gormConfig)
}
// ensureMySQLParams 确保MySQL DSN包含必要的参数
func ensureMySQLParams(dsn string) string {
params := map[string]string{
"charset": "utf8mb4",
"parseTime": "True",
"loc": "Local",
}
separator := "?"
if len(dsn) > 0 && dsn[len(dsn)-1] == '?' {
separator = ""
} else if contains(dsn, "?") {
separator = "&"
}
for key, value := range params {
if !contains(dsn, key+"=") {
dsn = fmt.Sprintf("%s%s%s=%s", dsn, separator, key, value)
separator = "&"
}
}
return dsn
}
// contains 检查字符串是否包含子串
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// getLogLevel 获取GORM日志级别
func getLogLevel(level string) logger.LogLevel {
switch level {
case "silent":
return logger.Silent
case "error":
return logger.Error
case "warn", "warning":
return logger.Warn
case "info":
return logger.Info
default:
return logger.Warn
}
}
// AutoMigrate 自动迁移数据库表
func AutoMigrate() error {
models := []interface{}{
&storage.Object{},
&storage.BucketStats{},
&storage.UploadSession{},
&storage.AccessLog{},
}
for _, model := range models {
if err := DB.AutoMigrate(model); err != nil {
return fmt.Errorf("failed to migrate %T: %w", model, err)
}
}
log.Println("Database migration completed successfully")
return nil
}
// Close 关闭数据库连接
func Close() error {
if DB != nil {
sqlDB, err := DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}
// HealthCheck 健康检查
func HealthCheck() error {
if DB == nil {
return fmt.Errorf("database not initialized")
}
sqlDB, err := DB.DB()
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return sqlDB.PingContext(ctx)
}
// Transaction 执行事务
func Transaction(fn func(*gorm.DB) error) error {
return DB.Transaction(fn)
}
// GetDB 获取数据库连接
func GetDB() *gorm.DB {
return DB
}

163
internal/storage/models.go Normal file
View File

@@ -0,0 +1,163 @@
package storage
import (
"database/sql/driver"
"encoding/json"
"time"
"gorm.io/gorm"
)
// Object 对象信息模型
type Object struct {
ID uint `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
Size int64 `gorm:"not null;default:0" json:"size"`
Metadata JSON `gorm:"type:json" json:"metadata,omitempty"`
ContentType string `gorm:"size:128" json:"content_type,omitempty"`
ETag string `gorm:"size:128" json:"etag,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 指定表名
func (Object) TableName() string {
return "objects"
}
// BucketStats 存储桶统计信息模型
type BucketStats struct {
ID uint `gorm:"primaryKey" json:"id"`
BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"`
ObjectCount int64 `gorm:"not null;default:0" json:"object_count"`
TotalSize int64 `gorm:"not null;default:0" json:"total_size"`
LastCheckedAt time.Time `json:"last_checked_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName 指定表名
func (BucketStats) TableName() string {
return "bucket_stats"
}
// UploadSession 上传会话模型(用于跟踪分片上传)
type UploadSession struct {
ID uint `gorm:"primaryKey" json:"id"`
UploadID string `gorm:"uniqueIndex;size:255;not null" json:"upload_id"`
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
TotalParts int `gorm:"not null;default:0" json:"total_parts"`
CompletedParts int `gorm:"not null;default:0" json:"completed_parts"`
Size int64 `gorm:"not null;default:0" json:"size"`
Status string `gorm:"size:32;not null;default:'pending'" json:"status"` // pending, completed, aborted
ExpiresAt time.Time `gorm:"index" json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 指定表名
func (UploadSession) TableName() string {
return "upload_sessions"
}
// AccessLog 访问日志模型
type AccessLog struct {
ID uint `gorm:"primaryKey" json:"id"`
Action string `gorm:"index;size:32;not null" json:"action"` // upload, download, delete
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255" json:"bucket_name"`
Size int64 `gorm:"default:0" json:"size"`
ClientIP string `gorm:"size:64" json:"client_ip"`
UserAgent string `gorm:"size:512" json:"user_agent"`
Success bool `gorm:"default:true" json:"success"`
ErrorMsg string `gorm:"type:text" json:"error_msg,omitempty"`
ResponseTime int64 `gorm:"default:0" json:"response_time"` // 响应时间(毫秒)
CreatedAt time.Time `gorm:"index" json:"created_at"`
}
// TableName 指定表名
func (AccessLog) TableName() string {
return "access_logs"
}
// JSON 自定义JSON类型用于存储元数据
type JSON map[string]interface{}
// Value 实现driver.Valuer接口
func (j JSON) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
// Scan 实现sql.Scanner接口
func (j *JSON) Scan(value interface{}) error {
if value == nil {
*j = make(map[string]interface{})
return nil
}
var data []byte
switch v := value.(type) {
case []byte:
data = v
case string:
data = []byte(v)
default:
data = []byte("{}")
}
return json.Unmarshal(data, j)
}
// BeforeCreate GORM钩子 - 创建前
func (o *Object) BeforeCreate(tx *gorm.DB) error {
if o.Metadata == nil {
o.Metadata = make(JSON)
}
return nil
}
// BeforeCreate GORM钩子 - 创建前设置过期时间
func (u *UploadSession) BeforeCreate(tx *gorm.DB) error {
if u.ExpiresAt.IsZero() {
u.ExpiresAt = time.Now().Add(24 * time.Hour) // 默认24小时过期
}
return nil
}
// IsExpired 检查上传会话是否过期
func (u *UploadSession) IsExpired() bool {
return time.Now().After(u.ExpiresAt)
}
// ObjectFilter 对象查询过滤器
type ObjectFilter struct {
Key string
BucketName string
Prefix string
MinSize int64
MaxSize int64
StartTime time.Time
EndTime time.Time
Limit int
Offset int
}
// AccessLogFilter 访问日志查询过滤器
type AccessLogFilter struct {
Action string
Key string
BucketName string
ClientIP string
Success *bool
StartTime time.Time
EndTime time.Time
Limit int
Offset int
}

350
internal/storage/service.go Normal file
View File

@@ -0,0 +1,350 @@
package storage
import (
"fmt"
"time"
"gorm.io/gorm"
)
// Service 存储服务(管理对象元数据)
type Service struct {
db *gorm.DB
}
// NewService 创建新的存储服务
func NewService(db *gorm.DB) *Service {
return &Service{
db: db,
}
}
// RecordObject 记录对象信息
func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[string]string) error {
obj := &Object{
Key: key,
BucketName: bucketName,
Size: size,
}
if len(metadata) > 0 {
obj.Metadata = make(JSON)
for k, v := range metadata {
obj.Metadata[k] = v
}
} else {
obj.Metadata = make(JSON)
}
// 使用 Upsert更新或插入
result := s.db.Where("key = ?", key).FirstOrCreate(&obj)
if result.Error != nil {
return fmt.Errorf("failed to record object: %w", result.Error)
}
if result.RowsAffected == 0 {
// 对象已存在,更新它
updates := map[string]interface{}{
"bucket_name": bucketName,
"size": size,
"metadata": obj.Metadata,
"updated_at": time.Now(),
}
if err := s.db.Model(&Object{}).Where("key = ?", key).Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update object: %w", err)
}
}
// 更新存储桶统计
s.updateBucketStats(bucketName)
return nil
}
// FindObjectBucket 查找对象所在的存储桶
func (s *Service) FindObjectBucket(key string) (string, error) {
var obj Object
if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return "", fmt.Errorf("object not found: %s", key)
}
return "", fmt.Errorf("failed to find object: %w", err)
}
return obj.BucketName, nil
}
// GetObjectInfo 获取对象信息
func (s *Service) GetObjectInfo(key string) (*Object, error) {
var obj Object
if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("object not found: %s", key)
}
return nil, fmt.Errorf("failed to get object info: %w", err)
}
return &obj, nil
}
// DeleteObject 删除对象记录(软删除)
func (s *Service) DeleteObject(key string) error {
var obj Object
if err := s.db.Where("key = ?", key).First(&obj).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return fmt.Errorf("object not found: %s", key)
}
return fmt.Errorf("failed to find object: %w", err)
}
bucketName := obj.BucketName
// 软删除
if err := s.db.Delete(&obj).Error; err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
// 更新存储桶统计
s.updateBucketStats(bucketName)
return nil
}
// ListObjects 列出对象支持S3兼容的参数
func (s *Service) ListObjects(bucketName, prefix, marker string, maxKeys int) ([]*Object, error) {
var objects []*Object
query := s.db.Model(&Object{})
// 按bucket过滤
if bucketName != "" {
query = query.Where("bucket_name = ?", bucketName)
}
// 前缀过滤
if prefix != "" {
query = query.Where("key LIKE ?", prefix+"%")
}
// Marker分页
if marker != "" {
query = query.Where("key > ?", marker)
}
// 限制返回数量
if maxKeys > 0 {
query = query.Limit(maxKeys)
}
// 按key字母顺序排序S3标准
if err := query.Order("key ASC").Find(&objects).Error; err != nil {
return nil, fmt.Errorf("failed to list objects: %w", err)
}
return objects, nil
}
// GetBucketObjects 获取特定存储桶的所有对象
func (s *Service) GetBucketObjects(bucketName string) ([]*Object, error) {
var objects []*Object
if err := s.db.Where("bucket_name = ?", bucketName).Find(&objects).Error; err != nil {
return nil, fmt.Errorf("failed to get bucket objects: %w", err)
}
return objects, nil
}
// GetTotalSize 获取所有对象的总大小
func (s *Service) GetTotalSize() (int64, error) {
var total int64
if err := s.db.Model(&Object{}).Select("COALESCE(SUM(size), 0)").Scan(&total).Error; err != nil {
return 0, fmt.Errorf("failed to get total size: %w", err)
}
return total, nil
}
// GetBucketSize 获取特定存储桶的总大小
func (s *Service) GetBucketSize(bucketName string) (int64, error) {
var total int64
if err := s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Select("COALESCE(SUM(size), 0)").
Scan(&total).Error; err != nil {
return 0, fmt.Errorf("failed to get bucket size: %w", err)
}
return total, nil
}
// GetObjectCount 获取对象总数
func (s *Service) GetObjectCount() (int64, error) {
var count int64
if err := s.db.Model(&Object{}).Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to get object count: %w", err)
}
return count, nil
}
// GetBucketObjectCount 获取特定存储桶的对象数
func (s *Service) GetBucketObjectCount(bucketName string) (int64, error) {
var count int64
if err := s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to get bucket object count: %w", err)
}
return count, nil
}
// updateBucketStats 更新存储桶统计信息
func (s *Service) updateBucketStats(bucketName string) error {
var stats BucketStats
// 获取或创建统计记录
result := s.db.Where("bucket_name = ?", bucketName).FirstOrCreate(&stats, BucketStats{
BucketName: bucketName,
})
if result.Error != nil {
return fmt.Errorf("failed to get bucket stats: %w", result.Error)
}
// 计算新的统计数据
var count int64
var totalSize int64
s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Count(&count)
s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Select("COALESCE(SUM(size), 0)").
Scan(&totalSize)
// 更新统计数据
updates := map[string]interface{}{
"object_count": count,
"total_size": totalSize,
"last_checked_at": time.Now(),
}
if err := s.db.Model(&stats).Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update bucket stats: %w", err)
}
return nil
}
// RecordUploadSession 记录上传会话
func (s *Service) RecordUploadSession(uploadID, key, bucketName string, totalParts int, size int64) error {
session := &UploadSession{
UploadID: uploadID,
Key: key,
BucketName: bucketName,
TotalParts: totalParts,
Size: size,
Status: "pending",
}
if err := s.db.Create(session).Error; err != nil {
return fmt.Errorf("failed to record upload session: %w", err)
}
return nil
}
// GetUploadSession 获取上传会话
func (s *Service) GetUploadSession(uploadID string) (*UploadSession, error) {
var session UploadSession
if err := s.db.Where("upload_id = ?", uploadID).First(&session).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("upload session not found: %s", uploadID)
}
return nil, fmt.Errorf("failed to get upload session: %w", err)
}
return &session, nil
}
// UpdateUploadSession 更新上传会话
func (s *Service) UpdateUploadSession(uploadID string, completedParts int, status string) error {
updates := map[string]interface{}{
"completed_parts": completedParts,
"status": status,
"updated_at": time.Now(),
}
if err := s.db.Model(&UploadSession{}).
Where("upload_id = ?", uploadID).
Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update upload session: %w", err)
}
return nil
}
// CleanExpiredSessions 清理过期的上传会话
func (s *Service) CleanExpiredSessions() error {
if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending").
Delete(&UploadSession{}).Error; err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}
return nil
}
// RecordAccessLog 记录访问日志
func (s *Service) RecordAccessLog(action, key, bucketName, clientIP, userAgent string, size int64, success bool, errorMsg string, responseTime int64) error {
log := &AccessLog{
Action: action,
Key: key,
BucketName: bucketName,
ClientIP: clientIP,
UserAgent: userAgent,
Size: size,
Success: success,
ErrorMsg: errorMsg,
ResponseTime: responseTime,
}
if err := s.db.Create(log).Error; err != nil {
return fmt.Errorf("failed to record access log: %w", err)
}
return nil
}
// GetAccessLogs 获取访问日志
func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) {
query := s.db.Model(&AccessLog{})
if filter != nil {
if filter.Action != "" {
query = query.Where("action = ?", filter.Action)
}
if filter.Key != "" {
query = query.Where("key = ?", filter.Key)
}
if filter.BucketName != "" {
query = query.Where("bucket_name = ?", filter.BucketName)
}
if filter.ClientIP != "" {
query = query.Where("client_ip = ?", filter.ClientIP)
}
if filter.Success != nil {
query = query.Where("success = ?", *filter.Success)
}
if !filter.StartTime.IsZero() {
query = query.Where("created_at >= ?", filter.StartTime)
}
if !filter.EndTime.IsZero() {
query = query.Where("created_at <= ?", filter.EndTime)
}
if filter.Limit > 0 {
query = query.Limit(filter.Limit)
}
if filter.Offset > 0 {
query = query.Offset(filter.Offset)
}
}
var logs []*AccessLog
if err := query.Order("created_at DESC").Find(&logs).Error; err != nil {
return nil, fmt.Errorf("failed to get access logs: %w", err)
}
return logs, nil
}

View File

@@ -0,0 +1,85 @@
package presigner
import (
"context"
"fmt"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// CompletedPart 已完成的分片信息
type CompletedPart struct {
PartNumber int32 `json:"part_number"`
ETag string `json:"etag"`
}
// CompleteMultipartUpload 完成分片上传
func CompleteMultipartUpload(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string, parts []CompletedPart) error {
// 转换为AWS SDK格式
var completedParts []types.CompletedPart
for _, part := range parts {
completedParts = append(completedParts, types.CompletedPart{
PartNumber: aws.Int32(part.PartNumber),
ETag: aws.String(part.ETag),
})
}
// 完成分片上传
_, err := bucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
MultipartUpload: &types.CompletedMultipartUpload{
Parts: completedParts,
},
})
if err != nil {
return fmt.Errorf("failed to complete multipart upload: %w", err)
}
return nil
}
// AbortMultipartUpload 中止分片上传
func AbortMultipartUpload(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string) error {
_, err := bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
})
if err != nil {
return fmt.Errorf("failed to abort multipart upload: %w", err)
}
return nil
}
// ListParts 列出已上传的分片
func ListParts(ctx context.Context, bucket *bucket.BucketInfo, key, uploadID string) ([]types.Part, error) {
var allParts []types.Part
var nextPartNumberMarker *string
for {
output, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
PartNumberMarker: nextPartNumberMarker,
})
if err != nil {
return nil, fmt.Errorf("failed to list parts: %w", err)
}
allParts = append(allParts, output.Parts...)
if output.IsTruncated == nil || !*output.IsTruncated {
break
}
nextPartNumberMarker = output.NextPartNumberMarker
}
return allParts, nil
}

221
pkg/presigner/presigner.go Normal file
View File

@@ -0,0 +1,221 @@
package presigner
import (
"context"
"fmt"
"time"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
// Presigner 预签名URL生成器
type Presigner struct {
uploadExpiry time.Duration
downloadExpiry time.Duration
}
// NewPresigner 创建新的预签名URL生成器
func NewPresigner(uploadExpiry, downloadExpiry time.Duration) *Presigner {
// 设置默认值
if uploadExpiry == 0 {
uploadExpiry = 15 * time.Minute
}
if downloadExpiry == 0 {
downloadExpiry = 60 * time.Minute
}
return &Presigner{
uploadExpiry: uploadExpiry,
downloadExpiry: downloadExpiry,
}
}
// UploadURL 生成上传预签名URL
type UploadURL struct {
URL string `json:"url"`
Method string `json:"method"`
Headers map[string]string `json:"headers,omitempty"`
Expiry time.Time `json:"expiry"`
BucketName string `json:"bucket_name"`
Key string `json:"key"`
}
// GenerateUploadURL 生成上传预签名URL
func (p *Presigner) GenerateUploadURL(ctx context.Context, bucket *bucket.BucketInfo, key string, contentType string, metadata map[string]string) (*UploadURL, error) {
presignClient := s3.NewPresignClient(bucket.Client)
// 构建PutObject请求
putObjectInput := &s3.PutObjectInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
}
// 设置Content-Type
if contentType != "" {
putObjectInput.ContentType = aws.String(contentType)
}
// 设置元数据
if len(metadata) > 0 {
putObjectInput.Metadata = metadata
}
// 生成预签名URL
presignRequest, err := presignClient.PresignPutObject(ctx, putObjectInput, func(opts *s3.PresignOptions) {
opts.Expires = p.uploadExpiry
})
if err != nil {
return nil, fmt.Errorf("failed to generate upload presigned URL: %w", err)
}
// 转换Headers为map[string]string
headers := make(map[string]string)
for k, v := range presignRequest.SignedHeader {
if len(v) > 0 {
headers[k] = v[0]
}
}
return &UploadURL{
URL: presignRequest.URL,
Method: presignRequest.Method,
Headers: headers,
Expiry: time.Now().Add(p.uploadExpiry),
BucketName: bucket.Config.Name,
Key: key,
}, nil
}
// DownloadURL 生成下载预签名URL
type DownloadURL struct {
URL string `json:"url"`
Method string `json:"method"`
Expiry time.Time `json:"expiry"`
BucketName string `json:"bucket_name"`
Key string `json:"key"`
}
// GenerateDownloadURL 生成下载预签名URL
func (p *Presigner) GenerateDownloadURL(ctx context.Context, bucket *bucket.BucketInfo, key string) (*DownloadURL, error) {
presignClient := s3.NewPresignClient(bucket.Client)
// 构建GetObject请求
getObjectInput := &s3.GetObjectInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
}
// 生成预签名URL
presignRequest, err := presignClient.PresignGetObject(ctx, getObjectInput, func(opts *s3.PresignOptions) {
opts.Expires = p.downloadExpiry
})
if err != nil {
return nil, fmt.Errorf("failed to generate download presigned URL: %w", err)
}
return &DownloadURL{
URL: presignRequest.URL,
Method: presignRequest.Method,
Expiry: time.Now().Add(p.downloadExpiry),
BucketName: bucket.Config.Name,
Key: key,
}, nil
}
// DeleteURL 生成删除预签名URL
type DeleteURL struct {
URL string `json:"url"`
Method string `json:"method"`
Expiry time.Time `json:"expiry"`
BucketName string `json:"bucket_name"`
Key string `json:"key"`
}
// GenerateDeleteURL 生成删除预签名URL
func (p *Presigner) GenerateDeleteURL(ctx context.Context, bucket *bucket.BucketInfo, key string) (*DeleteURL, error) {
presignClient := s3.NewPresignClient(bucket.Client)
// 构建DeleteObject请求
deleteObjectInput := &s3.DeleteObjectInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
}
// 生成预签名URL
presignRequest, err := presignClient.PresignDeleteObject(ctx, deleteObjectInput, func(opts *s3.PresignOptions) {
opts.Expires = 5 * time.Minute // 删除操作的URL有效期较短
})
if err != nil {
return nil, fmt.Errorf("failed to generate delete presigned URL: %w", err)
}
return &DeleteURL{
URL: presignRequest.URL,
Method: presignRequest.Method,
Expiry: time.Now().Add(5 * time.Minute),
BucketName: bucket.Config.Name,
Key: key,
}, nil
}
// MultipartUploadURLs 分片上传预签名URLs
type MultipartUploadURLs struct {
UploadID string `json:"upload_id"`
PartURLs map[int]string `json:"part_urls"`
BucketName string `json:"bucket_name"`
Key string `json:"key"`
Expiry time.Time `json:"expiry"`
}
// GenerateMultipartUploadURLs 生成分片上传预签名URLs
func (p *Presigner) GenerateMultipartUploadURLs(ctx context.Context, bucket *bucket.BucketInfo, key string, partCount int) (*MultipartUploadURLs, error) {
// 初始化分片上传
createResp, err := bucket.Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
})
if err != nil {
return nil, fmt.Errorf("failed to create multipart upload: %w", err)
}
presignClient := s3.NewPresignClient(bucket.Client)
partURLs := make(map[int]string)
// 为每个分片生成预签名URL
for i := 1; i <= partCount; i++ {
uploadPartInput := &s3.UploadPartInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: createResp.UploadId,
PartNumber: aws.Int32(int32(i)),
}
presignRequest, err := presignClient.PresignUploadPart(ctx, uploadPartInput, func(opts *s3.PresignOptions) {
opts.Expires = p.uploadExpiry
})
if err != nil {
// 如果失败,中止分片上传
bucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: createResp.UploadId,
})
return nil, fmt.Errorf("failed to generate part %d presigned URL: %w", i, err)
}
partURLs[i] = presignRequest.URL
}
// 注意CompleteMultipartUpload 和 AbortMultipartUpload 需要在客户端直接调用
// 因为它们需要提供额外的参数如Parts列表不适合预签名
return &MultipartUploadURLs{
UploadID: *createResp.UploadId,
PartURLs: partURLs,
BucketName: bucket.Config.Name,
Key: key,
Expiry: time.Now().Add(p.uploadExpiry),
}, nil
}

347
test_s3_compatibility.py Normal file
View File

@@ -0,0 +1,347 @@
#!/usr/bin/env python3
"""
S3 Balance - S3兼容性测试脚本
使用boto3 AWS SDK测试S3 Balance的S3兼容性
"""
import os
import sys
import time
import hashlib
import tempfile
from datetime import datetime
try:
import boto3
from botocore.client import Config
except ImportError:
print("请先安装boto3: pip install boto3")
sys.exit(1)
# S3 Balance服务配置
S3_BALANCE_ENDPOINT = "http://localhost:8080"
ACCESS_KEY = "AKIAIOSFODNN7EXAMPLE"
SECRET_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
# 测试配置
TEST_BUCKET = "test-bucket-1"
TEST_KEY_PREFIX = f"test-{int(time.time())}"
def create_s3_client():
"""创建S3客户端"""
return boto3.client(
's3',
endpoint_url=S3_BALANCE_ENDPOINT,
aws_access_key_id=ACCESS_KEY,
aws_secret_access_key=SECRET_KEY,
config=Config(
signature_version='s3v4',
s3={'addressing_style': 'path'}
),
region_name='us-east-1'
)
def test_list_buckets(s3_client):
"""测试列出存储桶"""
print("\n1. 测试列出存储桶 (ListBuckets)...")
try:
response = s3_client.list_buckets()
buckets = response.get('Buckets', [])
print(f" ✓ 找到 {len(buckets)} 个存储桶")
for bucket in buckets:
print(f" - {bucket['Name']} (创建时间: {bucket['CreationDate']})")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def test_upload_object(s3_client):
"""测试上传对象"""
print(f"\n2. 测试上传对象 (PutObject)...")
# 创建测试文件
test_data = b"Hello, S3 Balance! This is a test file."
test_key = f"{TEST_KEY_PREFIX}/test-upload.txt"
try:
# 上传对象
response = s3_client.put_object(
Bucket=TEST_BUCKET,
Key=test_key,
Body=test_data,
ContentType='text/plain',
Metadata={'test': 'true', 'timestamp': str(int(time.time()))}
)
etag = response.get('ETag', '').strip('"')
print(f" ✓ 成功上传对象: {test_key}")
print(f" ETag: {etag}")
return test_key
except Exception as e:
print(f" ✗ 失败: {e}")
return None
def test_list_objects(s3_client):
"""测试列出对象"""
print(f"\n3. 测试列出对象 (ListObjects)...")
try:
response = s3_client.list_objects_v2(
Bucket=TEST_BUCKET,
Prefix=TEST_KEY_PREFIX,
MaxKeys=10
)
objects = response.get('Contents', [])
print(f" ✓ 找到 {len(objects)} 个对象")
for obj in objects:
print(f" - {obj['Key']} (大小: {obj['Size']} bytes, 修改时间: {obj['LastModified']})")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def test_download_object(s3_client, key):
"""测试下载对象"""
print(f"\n4. 测试下载对象 (GetObject)...")
if not key:
print(" ⚠ 跳过: 没有可下载的对象")
return False
try:
response = s3_client.get_object(
Bucket=TEST_BUCKET,
Key=key
)
data = response['Body'].read()
content_type = response.get('ContentType', '')
content_length = response.get('ContentLength', 0)
print(f" ✓ 成功下载对象: {key}")
print(f" 内容类型: {content_type}")
print(f" 内容长度: {content_length} bytes")
print(f" 内容预览: {data[:50].decode('utf-8', errors='ignore')}...")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def test_head_object(s3_client, key):
"""测试获取对象元数据"""
print(f"\n5. 测试获取对象元数据 (HeadObject)...")
if not key:
print(" ⚠ 跳过: 没有可查询的对象")
return False
try:
response = s3_client.head_object(
Bucket=TEST_BUCKET,
Key=key
)
print(f" ✓ 成功获取对象元数据: {key}")
print(f" 内容长度: {response.get('ContentLength', 0)} bytes")
print(f" 内容类型: {response.get('ContentType', '')}")
print(f" ETag: {response.get('ETag', '').strip('\"')}")
print(f" 最后修改: {response.get('LastModified', '')}")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def test_multipart_upload(s3_client):
"""测试分片上传(大文件)"""
print(f"\n6. 测试分片上传 (Multipart Upload)...")
# 创建一个5MB的测试文件
test_key = f"{TEST_KEY_PREFIX}/test-multipart.bin"
part_size = 5 * 1024 * 1024 # 5MB per part
total_size = 10 * 1024 * 1024 # 10MB total
try:
# 初始化分片上传
response = s3_client.create_multipart_upload(
Bucket=TEST_BUCKET,
Key=test_key,
ContentType='application/octet-stream'
)
upload_id = response['UploadId']
print(f" ✓ 初始化分片上传UploadId: {upload_id}")
# 上传分片
parts = []
for i in range(2): # 上传2个5MB的分片
part_number = i + 1
part_data = os.urandom(part_size) # 生成随机数据
part_response = s3_client.upload_part(
Bucket=TEST_BUCKET,
Key=test_key,
PartNumber=part_number,
UploadId=upload_id,
Body=part_data
)
parts.append({
'ETag': part_response['ETag'],
'PartNumber': part_number
})
print(f" ✓ 上传分片 {part_number}/2 完成")
# 完成分片上传
s3_client.complete_multipart_upload(
Bucket=TEST_BUCKET,
Key=test_key,
UploadId=upload_id,
MultipartUpload={'Parts': parts}
)
print(f" ✓ 分片上传完成: {test_key}")
return test_key
except Exception as e:
print(f" ✗ 失败: {e}")
return None
def test_delete_object(s3_client, key):
"""测试删除对象"""
print(f"\n7. 测试删除对象 (DeleteObject)...")
if not key:
print(" ⚠ 跳过: 没有可删除的对象")
return False
try:
s3_client.delete_object(
Bucket=TEST_BUCKET,
Key=key
)
print(f" ✓ 成功删除对象: {key}")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def test_presigned_url(s3_client):
"""测试预签名URL"""
print(f"\n8. 测试预签名URL...")
test_key = f"{TEST_KEY_PREFIX}/test-presigned.txt"
try:
# 生成上传预签名URL
upload_url = s3_client.generate_presigned_url(
'put_object',
Params={'Bucket': TEST_BUCKET, 'Key': test_key},
ExpiresIn=3600
)
print(f" ✓ 生成上传预签名URL")
print(f" URL: {upload_url[:80]}...")
# 生成下载预签名URL
download_url = s3_client.generate_presigned_url(
'get_object',
Params={'Bucket': TEST_BUCKET, 'Key': test_key},
ExpiresIn=3600
)
print(f" ✓ 生成下载预签名URL")
print(f" URL: {download_url[:80]}...")
return True
except Exception as e:
print(f" ✗ 失败: {e}")
return False
def cleanup(s3_client):
"""清理测试数据"""
print(f"\n清理测试数据...")
try:
# 列出所有测试对象
response = s3_client.list_objects_v2(
Bucket=TEST_BUCKET,
Prefix=TEST_KEY_PREFIX
)
objects = response.get('Contents', [])
if objects:
# 删除所有测试对象
delete_objects = [{'Key': obj['Key']} for obj in objects]
s3_client.delete_objects(
Bucket=TEST_BUCKET,
Delete={'Objects': delete_objects}
)
print(f" ✓ 删除了 {len(objects)} 个测试对象")
else:
print(" ✓ 没有需要清理的对象")
return True
except Exception as e:
print(f" ✗ 清理失败: {e}")
return False
def main():
"""主测试函数"""
print("=" * 60)
print("S3 Balance - S3兼容性测试")
print("=" * 60)
print(f"端点: {S3_BALANCE_ENDPOINT}")
print(f"测试桶: {TEST_BUCKET}")
print(f"测试前缀: {TEST_KEY_PREFIX}")
# 创建S3客户端
s3_client = create_s3_client()
# 执行测试
results = []
# 基础测试
results.append(("ListBuckets", test_list_buckets(s3_client)))
# 对象操作测试
uploaded_key = test_upload_object(s3_client)
results.append(("PutObject", uploaded_key is not None))
results.append(("ListObjects", test_list_objects(s3_client)))
results.append(("GetObject", test_download_object(s3_client, uploaded_key)))
results.append(("HeadObject", test_head_object(s3_client, uploaded_key)))
# 高级功能测试
# multipart_key = test_multipart_upload(s3_client)
# results.append(("MultipartUpload", multipart_key is not None))
results.append(("PresignedURL", test_presigned_url(s3_client)))
# 删除测试
results.append(("DeleteObject", test_delete_object(s3_client, uploaded_key)))
# 清理
cleanup(s3_client)
# 打印测试结果摘要
print("\n" + "=" * 60)
print("测试结果摘要")
print("=" * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
for test_name, result in results:
status = "✓ 通过" if result else "✗ 失败"
print(f"{test_name:20} {status}")
print("-" * 60)
print(f"总计: {passed}/{total} 测试通过")
if passed == total:
print("\n🎉 所有测试通过S3 Balance S3兼容性良好。")
else:
print(f"\n⚠️ 有 {total - passed} 个测试失败,请检查服务配置。")
return 0 if passed == total else 1
if __name__ == "__main__":
sys.exit(main())