middleware

This commit is contained in:
DullJZ
2025-09-29 23:55:35 +08:00
parent 99fec072af
commit 15a39aa632
6 changed files with 174 additions and 148 deletions

View File

@@ -64,14 +64,14 @@ func main() {
if err != nil {
log.Fatalf("Failed to create balancer: %v", err)
}
// 设置指标服务
lb.SetMetrics(metricsService)
// 创建预签名URL生成器
signer := presigner.NewPresigner(
15*time.Minute, // 上传URL有效期
60*time.Minute, // 下载URL有效期
15*time.Minute, // 上传URL有效期
60*time.Minute, // 下载URL有效期
)
// 创建存储服务
@@ -113,13 +113,13 @@ func main() {
// 设置路由
router := mux.NewRouter()
// 添加指标端点
if cfg.Metrics.Enabled {
router.Path(cfg.Metrics.Path).Handler(promhttp.Handler())
log.Printf("Metrics server enabled at %s", cfg.Metrics.Path)
}
// 运行在S3兼容模式
log.Println("Running in S3-compatible mode")
s3Handler.RegisterS3Routes(router)
@@ -144,7 +144,7 @@ func main() {
log.Printf("Starting S3 Balance Service on %s", srv.Addr)
log.Printf("Load balancing strategy: %s", cfg.Balancer.Strategy)
log.Printf("Managed buckets: %d", len(cfg.Buckets))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start server: %v", err)
}
@@ -157,7 +157,7 @@ func main() {
// 优雅关闭
log.Println("Shutting down server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
@@ -178,12 +178,12 @@ func corsMiddleware(next http.Handler) http.Handler {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
@@ -192,15 +192,15 @@ func corsMiddleware(next http.Handler) http.Handler {
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 包装ResponseWriter以捕获状态码
wrapped := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
log.Printf(
"[%s] %s %s %d %v",
r.RemoteAddr,
@@ -228,11 +228,11 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
go func() {
// 初始延迟,避免启动时立即执行
time.Sleep(1 * time.Minute)
// 每小时清理一次过期的上传会话
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
@@ -245,7 +245,7 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
} else {
log.Println("Successfully cleaned expired upload sessions")
}
// 同时中止在S3存储桶中过期的分片上传
cleanupS3MultipartUploads(ctx, storageService)
}
@@ -261,7 +261,7 @@ func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Serv
log.Printf("Failed to get pending sessions for cleanup: %v", err)
return
}
for _, session := range sessions {
if session.IsExpired() {
log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key)

View File

@@ -1,57 +0,0 @@
package api
import (
"encoding/base64"
"net/http"
"strings"
)
// authMiddleware 处理 Basic Auth 校验
func (h *S3Handler) authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !h.authRequired {
next.ServeHTTP(w, r)
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
h.requireAuth(w)
return
}
if strings.HasPrefix(authHeader, "Basic ") {
payload := strings.TrimPrefix(authHeader, "Basic ")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
h.requireAuth(w)
return
}
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
h.requireAuth(w)
return
}
if parts[0] != h.accessKey {
h.sendS3Error(w, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.", "")
return
}
if parts[1] != h.secretKey {
h.sendS3Error(w, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", "")
return
}
next.ServeHTTP(w, r)
return
}
h.requireAuth(w)
})
}
func (h *S3Handler) requireAuth(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
h.sendS3Error(w, "AccessDenied", "Access Denied", "")
}

View File

@@ -4,6 +4,7 @@ import (
"github.com/DullJZ/s3-balance/internal/balancer"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/metrics"
"github.com/DullJZ/s3-balance/internal/middleware"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/DullJZ/s3-balance/pkg/presigner"
"github.com/gorilla/mux"
@@ -75,7 +76,18 @@ func (h *S3Handler) RegisterS3Routes(router *mux.Router) {
// Object operations - must be registered after multipart operations to avoid conflicts
router.HandleFunc("/{bucket}/{key:.*}", h.handleObjectOperations).Methods("GET", "HEAD", "PUT", "DELETE")
// 添加认证中间件
router.Use(h.virtualHostMiddleware)
router.Use(h.authMiddleware)
// 添加中间件
router.Use(middleware.VirtualHost(middleware.VirtualHostConfig{
Enabled: h.virtualHost,
BucketExists: func(name string) bool {
_, ok := h.bucketManager.GetBucket(name)
return ok
},
}))
router.Use(middleware.BasicAuth(middleware.AuthConfig{
Required: h.authRequired,
AccessKey: h.accessKey,
SecretKey: h.secretKey,
OnError: h.sendS3Error,
}))
}

View File

@@ -1,72 +0,0 @@
package api
import (
"net"
"net/http"
"strings"
)
// virtualHostMiddleware 支持根据 Host 头推断存储桶名称
func (h *S3Handler) virtualHostMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !h.virtualHost {
next.ServeHTTP(w, r)
return
}
bucketName := h.bucketFromHost(r.Host)
if bucketName == "" {
next.ServeHTTP(w, r)
return
}
// 若路由中已包含桶名称则无需改写
if strings.HasPrefix(r.URL.Path, "/"+bucketName) {
next.ServeHTTP(w, r)
return
}
// 确保桶存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
next.ServeHTTP(w, r)
return
}
newPath := "/" + bucketName
if r.URL.Path != "/" {
newPath += r.URL.Path
}
clone := r.Clone(r.Context())
clone.URL.Path = newPath
clone.RequestURI = newPath
next.ServeHTTP(w, clone)
})
}
func (h *S3Handler) bucketFromHost(host string) string {
if host == "" {
return ""
}
cleanHost := host
if strings.Contains(host, ":") {
hostname, _, err := net.SplitHostPort(host)
if err == nil {
cleanHost = hostname
}
}
parts := strings.Split(cleanHost, ".")
if len(parts) == 0 {
return ""
}
candidate := parts[0]
if candidate == "" {
return ""
}
return candidate
}

View File

@@ -0,0 +1,70 @@
package middleware
import (
"encoding/base64"
"net/http"
"strings"
)
// AuthConfig controls Basic Auth validation.
type AuthConfig struct {
Required bool
AccessKey string
SecretKey string
OnError func(http.ResponseWriter, string, string, string)
}
// BasicAuth enforces static access/secret key authentication when Required is true.
func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !cfg.Required {
next.ServeHTTP(w, r)
return
}
authHeader := r.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, "Basic ") {
requireAuth(w, cfg)
return
}
payload := strings.TrimPrefix(authHeader, "Basic ")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
requireAuth(w, cfg)
return
}
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
requireAuth(w, cfg)
return
}
if parts[0] != cfg.AccessKey {
invokeOnError(w, cfg, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.")
return
}
if parts[1] != cfg.SecretKey {
invokeOnError(w, cfg, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.")
return
}
next.ServeHTTP(w, r)
})
}
}
func requireAuth(w http.ResponseWriter, cfg AuthConfig) {
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
invokeOnError(w, cfg, "AccessDenied", "Access Denied")
}
func invokeOnError(w http.ResponseWriter, cfg AuthConfig, code, message string) {
if cfg.OnError != nil {
cfg.OnError(w, code, message, "")
return
}
http.Error(w, message, http.StatusForbidden)
}

View File

@@ -0,0 +1,73 @@
package middleware
import (
"net"
"net/http"
"strings"
)
// VirtualHostConfig controls host-style bucket resolution.
type VirtualHostConfig struct {
Enabled bool
BucketExists func(string) bool
}
// VirtualHost rewrites host-style requests (bucket.example.com) into path-style paths.
func VirtualHost(cfg VirtualHostConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !cfg.Enabled {
next.ServeHTTP(w, r)
return
}
bucket := bucketFromHost(r.Host)
if bucket == "" {
next.ServeHTTP(w, r)
return
}
if cfg.BucketExists != nil && !cfg.BucketExists(bucket) {
next.ServeHTTP(w, r)
return
}
if strings.HasPrefix(r.URL.Path, "/"+bucket) {
next.ServeHTTP(w, r)
return
}
newPath := "/" + bucket
if r.URL.Path != "/" {
newPath += r.URL.Path
}
clone := r.Clone(r.Context())
clone.URL.Path = newPath
clone.RequestURI = newPath
next.ServeHTTP(w, clone)
})
}
}
func bucketFromHost(host string) string {
if host == "" {
return ""
}
hostname := host
if strings.Contains(host, ":") {
h, _, err := net.SplitHostPort(host)
if err == nil {
hostname = h
}
}
parts := strings.Split(hostname, ".")
if len(parts) == 0 {
return ""
}
return parts[0]
}