From 15a39aa632736db8ccb9d3dbcfb920012337b22a Mon Sep 17 00:00:00 2001 From: DullJZ <79080562+DullJZ@users.noreply.github.com> Date: Mon, 29 Sep 2025 23:55:35 +0800 Subject: [PATCH] middleware --- cmd/s3-balance/main.go | 32 +++++------ internal/api/auth_middleware.go | 57 ------------------- internal/api/s3_handler.go | 18 +++++- internal/api/virtual_host_middleware.go | 72 ------------------------ internal/middleware/auth.go | 70 ++++++++++++++++++++++++ internal/middleware/virtual_host.go | 73 +++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 148 deletions(-) delete mode 100644 internal/api/auth_middleware.go delete mode 100644 internal/api/virtual_host_middleware.go create mode 100644 internal/middleware/auth.go create mode 100644 internal/middleware/virtual_host.go diff --git a/cmd/s3-balance/main.go b/cmd/s3-balance/main.go index 0050862..7018974 100644 --- a/cmd/s3-balance/main.go +++ b/cmd/s3-balance/main.go @@ -64,14 +64,14 @@ func main() { if err != nil { log.Fatalf("Failed to create balancer: %v", err) } - + // 设置指标服务 lb.SetMetrics(metricsService) // 创建预签名URL生成器 signer := presigner.NewPresigner( - 15*time.Minute, // 上传URL有效期 - 60*time.Minute, // 下载URL有效期 + 15*time.Minute, // 上传URL有效期 + 60*time.Minute, // 下载URL有效期 ) // 创建存储服务 @@ -113,13 +113,13 @@ func main() { // 设置路由 router := mux.NewRouter() - + // 添加指标端点 if cfg.Metrics.Enabled { router.Path(cfg.Metrics.Path).Handler(promhttp.Handler()) log.Printf("Metrics server enabled at %s", cfg.Metrics.Path) } - + // 运行在S3兼容模式 log.Println("Running in S3-compatible mode") s3Handler.RegisterS3Routes(router) @@ -144,7 +144,7 @@ func main() { log.Printf("Starting S3 Balance Service on %s", srv.Addr) log.Printf("Load balancing strategy: %s", cfg.Balancer.Strategy) log.Printf("Managed buckets: %d", len(cfg.Buckets)) - + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Failed to start server: %v", err) } @@ -157,7 +157,7 @@ func main() { // 优雅关闭 log.Println("Shutting down server...") - + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) defer shutdownCancel() @@ -178,12 +178,12 @@ func corsMiddleware(next http.Handler) http.Handler { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - + if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } - + next.ServeHTTP(w, r) }) } @@ -192,15 +192,15 @@ func corsMiddleware(next http.Handler) http.Handler { func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - + // 包装ResponseWriter以捕获状态码 wrapped := &responseWriter{ ResponseWriter: w, statusCode: http.StatusOK, } - + next.ServeHTTP(wrapped, r) - + log.Printf( "[%s] %s %s %d %v", r.RemoteAddr, @@ -228,11 +228,11 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) { go func() { // 初始延迟,避免启动时立即执行 time.Sleep(1 * time.Minute) - + // 每小时清理一次过期的上传会话 ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -245,7 +245,7 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) { } else { log.Println("Successfully cleaned expired upload sessions") } - + // 同时中止在S3存储桶中过期的分片上传 cleanupS3MultipartUploads(ctx, storageService) } @@ -261,7 +261,7 @@ func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Serv log.Printf("Failed to get pending sessions for cleanup: %v", err) return } - + for _, session := range sessions { if session.IsExpired() { log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key) diff --git a/internal/api/auth_middleware.go b/internal/api/auth_middleware.go deleted file mode 100644 index 867baff..0000000 --- a/internal/api/auth_middleware.go +++ /dev/null @@ -1,57 +0,0 @@ -package api - -import ( - "encoding/base64" - "net/http" - "strings" -) - -// authMiddleware 处理 Basic Auth 校验 -func (h *S3Handler) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !h.authRequired { - next.ServeHTTP(w, r) - return - } - - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - h.requireAuth(w) - return - } - - if strings.HasPrefix(authHeader, "Basic ") { - payload := strings.TrimPrefix(authHeader, "Basic ") - decoded, err := base64.StdEncoding.DecodeString(payload) - if err != nil { - h.requireAuth(w) - return - } - - parts := strings.SplitN(string(decoded), ":", 2) - if len(parts) != 2 { - h.requireAuth(w) - return - } - - if parts[0] != h.accessKey { - h.sendS3Error(w, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.", "") - return - } - if parts[1] != h.secretKey { - h.sendS3Error(w, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", "") - return - } - - next.ServeHTTP(w, r) - return - } - - h.requireAuth(w) - }) -} - -func (h *S3Handler) requireAuth(w http.ResponseWriter) { - w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"") - h.sendS3Error(w, "AccessDenied", "Access Denied", "") -} diff --git a/internal/api/s3_handler.go b/internal/api/s3_handler.go index bceb40d..7859d52 100644 --- a/internal/api/s3_handler.go +++ b/internal/api/s3_handler.go @@ -4,6 +4,7 @@ import ( "github.com/DullJZ/s3-balance/internal/balancer" "github.com/DullJZ/s3-balance/internal/bucket" "github.com/DullJZ/s3-balance/internal/metrics" + "github.com/DullJZ/s3-balance/internal/middleware" "github.com/DullJZ/s3-balance/internal/storage" "github.com/DullJZ/s3-balance/pkg/presigner" "github.com/gorilla/mux" @@ -75,7 +76,18 @@ func (h *S3Handler) RegisterS3Routes(router *mux.Router) { // Object operations - must be registered after multipart operations to avoid conflicts router.HandleFunc("/{bucket}/{key:.*}", h.handleObjectOperations).Methods("GET", "HEAD", "PUT", "DELETE") - // 添加认证中间件 - router.Use(h.virtualHostMiddleware) - router.Use(h.authMiddleware) + // 添加中间件 + router.Use(middleware.VirtualHost(middleware.VirtualHostConfig{ + Enabled: h.virtualHost, + BucketExists: func(name string) bool { + _, ok := h.bucketManager.GetBucket(name) + return ok + }, + })) + router.Use(middleware.BasicAuth(middleware.AuthConfig{ + Required: h.authRequired, + AccessKey: h.accessKey, + SecretKey: h.secretKey, + OnError: h.sendS3Error, + })) } diff --git a/internal/api/virtual_host_middleware.go b/internal/api/virtual_host_middleware.go deleted file mode 100644 index 0109c27..0000000 --- a/internal/api/virtual_host_middleware.go +++ /dev/null @@ -1,72 +0,0 @@ -package api - -import ( - "net" - "net/http" - "strings" -) - -// virtualHostMiddleware 支持根据 Host 头推断存储桶名称 -func (h *S3Handler) virtualHostMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !h.virtualHost { - next.ServeHTTP(w, r) - return - } - - bucketName := h.bucketFromHost(r.Host) - if bucketName == "" { - next.ServeHTTP(w, r) - return - } - - // 若路由中已包含桶名称则无需改写 - if strings.HasPrefix(r.URL.Path, "/"+bucketName) { - next.ServeHTTP(w, r) - return - } - - // 确保桶存在 - if _, ok := h.bucketManager.GetBucket(bucketName); !ok { - next.ServeHTTP(w, r) - return - } - - newPath := "/" + bucketName - if r.URL.Path != "/" { - newPath += r.URL.Path - } - - clone := r.Clone(r.Context()) - clone.URL.Path = newPath - clone.RequestURI = newPath - - next.ServeHTTP(w, clone) - }) -} - -func (h *S3Handler) bucketFromHost(host string) string { - if host == "" { - return "" - } - - cleanHost := host - if strings.Contains(host, ":") { - hostname, _, err := net.SplitHostPort(host) - if err == nil { - cleanHost = hostname - } - } - - parts := strings.Split(cleanHost, ".") - if len(parts) == 0 { - return "" - } - - candidate := parts[0] - if candidate == "" { - return "" - } - - return candidate -} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..94c68ff --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "encoding/base64" + "net/http" + "strings" +) + +// AuthConfig controls Basic Auth validation. +type AuthConfig struct { + Required bool + AccessKey string + SecretKey string + OnError func(http.ResponseWriter, string, string, string) +} + +// BasicAuth enforces static access/secret key authentication when Required is true. +func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !cfg.Required { + next.ServeHTTP(w, r) + return + } + + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "Basic ") { + requireAuth(w, cfg) + return + } + + payload := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + requireAuth(w, cfg) + return + } + + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + requireAuth(w, cfg) + return + } + + if parts[0] != cfg.AccessKey { + invokeOnError(w, cfg, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.") + return + } + if parts[1] != cfg.SecretKey { + invokeOnError(w, cfg, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func requireAuth(w http.ResponseWriter, cfg AuthConfig) { + w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"") + invokeOnError(w, cfg, "AccessDenied", "Access Denied") +} + +func invokeOnError(w http.ResponseWriter, cfg AuthConfig, code, message string) { + if cfg.OnError != nil { + cfg.OnError(w, code, message, "") + return + } + http.Error(w, message, http.StatusForbidden) +} diff --git a/internal/middleware/virtual_host.go b/internal/middleware/virtual_host.go new file mode 100644 index 0000000..464372a --- /dev/null +++ b/internal/middleware/virtual_host.go @@ -0,0 +1,73 @@ +package middleware + +import ( + "net" + "net/http" + "strings" +) + +// VirtualHostConfig controls host-style bucket resolution. +type VirtualHostConfig struct { + Enabled bool + BucketExists func(string) bool +} + +// VirtualHost rewrites host-style requests (bucket.example.com) into path-style paths. +func VirtualHost(cfg VirtualHostConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !cfg.Enabled { + next.ServeHTTP(w, r) + return + } + + bucket := bucketFromHost(r.Host) + if bucket == "" { + next.ServeHTTP(w, r) + return + } + + if cfg.BucketExists != nil && !cfg.BucketExists(bucket) { + next.ServeHTTP(w, r) + return + } + + if strings.HasPrefix(r.URL.Path, "/"+bucket) { + next.ServeHTTP(w, r) + return + } + + newPath := "/" + bucket + if r.URL.Path != "/" { + newPath += r.URL.Path + } + + clone := r.Clone(r.Context()) + clone.URL.Path = newPath + clone.RequestURI = newPath + + next.ServeHTTP(w, clone) + }) + } +} + +func bucketFromHost(host string) string { + if host == "" { + return "" + } + + hostname := host + if strings.Contains(host, ":") { + h, _, err := net.SplitHostPort(host) + if err == nil { + hostname = h + } + } + + parts := strings.Split(hostname, ".") + if len(parts) == 0 { + return "" + } + + return parts[0] +}