mirror of
https://github.com/DullJZ/s3-balance.git
synced 2026-06-28 06:21:23 +08:00
middleware
This commit is contained in:
@@ -64,14 +64,14 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create balancer: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 设置指标服务
|
||||
lb.SetMetrics(metricsService)
|
||||
|
||||
// 创建预签名URL生成器
|
||||
signer := presigner.NewPresigner(
|
||||
15*time.Minute, // 上传URL有效期
|
||||
60*time.Minute, // 下载URL有效期
|
||||
15*time.Minute, // 上传URL有效期
|
||||
60*time.Minute, // 下载URL有效期
|
||||
)
|
||||
|
||||
// 创建存储服务
|
||||
@@ -113,13 +113,13 @@ func main() {
|
||||
|
||||
// 设置路由
|
||||
router := mux.NewRouter()
|
||||
|
||||
|
||||
// 添加指标端点
|
||||
if cfg.Metrics.Enabled {
|
||||
router.Path(cfg.Metrics.Path).Handler(promhttp.Handler())
|
||||
log.Printf("Metrics server enabled at %s", cfg.Metrics.Path)
|
||||
}
|
||||
|
||||
|
||||
// 运行在S3兼容模式
|
||||
log.Println("Running in S3-compatible mode")
|
||||
s3Handler.RegisterS3Routes(router)
|
||||
@@ -144,7 +144,7 @@ func main() {
|
||||
log.Printf("Starting S3 Balance Service on %s", srv.Addr)
|
||||
log.Printf("Load balancing strategy: %s", cfg.Balancer.Strategy)
|
||||
log.Printf("Managed buckets: %d", len(cfg.Buckets))
|
||||
|
||||
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func main() {
|
||||
|
||||
// 优雅关闭
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
@@ -178,12 +178,12 @@ func corsMiddleware(next http.Handler) http.Handler {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -192,15 +192,15 @@ func corsMiddleware(next http.Handler) http.Handler {
|
||||
func loggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
|
||||
// 包装ResponseWriter以捕获状态码
|
||||
wrapped := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
|
||||
log.Printf(
|
||||
"[%s] %s %s %d %v",
|
||||
r.RemoteAddr,
|
||||
@@ -228,11 +228,11 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
|
||||
go func() {
|
||||
// 初始延迟,避免启动时立即执行
|
||||
time.Sleep(1 * time.Minute)
|
||||
|
||||
|
||||
// 每小时清理一次过期的上传会话
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -245,7 +245,7 @@ func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
|
||||
} else {
|
||||
log.Println("Successfully cleaned expired upload sessions")
|
||||
}
|
||||
|
||||
|
||||
// 同时中止在S3存储桶中过期的分片上传
|
||||
cleanupS3MultipartUploads(ctx, storageService)
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Serv
|
||||
log.Printf("Failed to get pending sessions for cleanup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
for _, session := range sessions {
|
||||
if session.IsExpired() {
|
||||
log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key)
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// authMiddleware 处理 Basic Auth 校验
|
||||
func (h *S3Handler) authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.authRequired {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
h.requireAuth(w)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(authHeader, "Basic ") {
|
||||
payload := strings.TrimPrefix(authHeader, "Basic ")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
h.requireAuth(w)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
h.requireAuth(w)
|
||||
return
|
||||
}
|
||||
|
||||
if parts[0] != h.accessKey {
|
||||
h.sendS3Error(w, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.", "")
|
||||
return
|
||||
}
|
||||
if parts[1] != h.secretKey {
|
||||
h.sendS3Error(w, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", "")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
h.requireAuth(w)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *S3Handler) requireAuth(w http.ResponseWriter) {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
|
||||
h.sendS3Error(w, "AccessDenied", "Access Denied", "")
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/DullJZ/s3-balance/internal/balancer"
|
||||
"github.com/DullJZ/s3-balance/internal/bucket"
|
||||
"github.com/DullJZ/s3-balance/internal/metrics"
|
||||
"github.com/DullJZ/s3-balance/internal/middleware"
|
||||
"github.com/DullJZ/s3-balance/internal/storage"
|
||||
"github.com/DullJZ/s3-balance/pkg/presigner"
|
||||
"github.com/gorilla/mux"
|
||||
@@ -75,7 +76,18 @@ func (h *S3Handler) RegisterS3Routes(router *mux.Router) {
|
||||
// Object operations - must be registered after multipart operations to avoid conflicts
|
||||
router.HandleFunc("/{bucket}/{key:.*}", h.handleObjectOperations).Methods("GET", "HEAD", "PUT", "DELETE")
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(h.virtualHostMiddleware)
|
||||
router.Use(h.authMiddleware)
|
||||
// 添加中间件
|
||||
router.Use(middleware.VirtualHost(middleware.VirtualHostConfig{
|
||||
Enabled: h.virtualHost,
|
||||
BucketExists: func(name string) bool {
|
||||
_, ok := h.bucketManager.GetBucket(name)
|
||||
return ok
|
||||
},
|
||||
}))
|
||||
router.Use(middleware.BasicAuth(middleware.AuthConfig{
|
||||
Required: h.authRequired,
|
||||
AccessKey: h.accessKey,
|
||||
SecretKey: h.secretKey,
|
||||
OnError: h.sendS3Error,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// virtualHostMiddleware 支持根据 Host 头推断存储桶名称
|
||||
func (h *S3Handler) virtualHostMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.virtualHost {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
bucketName := h.bucketFromHost(r.Host)
|
||||
if bucketName == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 若路由中已包含桶名称则无需改写
|
||||
if strings.HasPrefix(r.URL.Path, "/"+bucketName) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 确保桶存在
|
||||
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
newPath := "/" + bucketName
|
||||
if r.URL.Path != "/" {
|
||||
newPath += r.URL.Path
|
||||
}
|
||||
|
||||
clone := r.Clone(r.Context())
|
||||
clone.URL.Path = newPath
|
||||
clone.RequestURI = newPath
|
||||
|
||||
next.ServeHTTP(w, clone)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *S3Handler) bucketFromHost(host string) string {
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
cleanHost := host
|
||||
if strings.Contains(host, ":") {
|
||||
hostname, _, err := net.SplitHostPort(host)
|
||||
if err == nil {
|
||||
cleanHost = hostname
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.Split(cleanHost, ".")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidate := parts[0]
|
||||
if candidate == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
70
internal/middleware/auth.go
Normal file
70
internal/middleware/auth.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthConfig controls Basic Auth validation.
|
||||
type AuthConfig struct {
|
||||
Required bool
|
||||
AccessKey string
|
||||
SecretKey string
|
||||
OnError func(http.ResponseWriter, string, string, string)
|
||||
}
|
||||
|
||||
// BasicAuth enforces static access/secret key authentication when Required is true.
|
||||
func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !cfg.Required {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
requireAuth(w, cfg)
|
||||
return
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(authHeader, "Basic ")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
requireAuth(w, cfg)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
requireAuth(w, cfg)
|
||||
return
|
||||
}
|
||||
|
||||
if parts[0] != cfg.AccessKey {
|
||||
invokeOnError(w, cfg, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.")
|
||||
return
|
||||
}
|
||||
if parts[1] != cfg.SecretKey {
|
||||
invokeOnError(w, cfg, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func requireAuth(w http.ResponseWriter, cfg AuthConfig) {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
|
||||
invokeOnError(w, cfg, "AccessDenied", "Access Denied")
|
||||
}
|
||||
|
||||
func invokeOnError(w http.ResponseWriter, cfg AuthConfig, code, message string) {
|
||||
if cfg.OnError != nil {
|
||||
cfg.OnError(w, code, message, "")
|
||||
return
|
||||
}
|
||||
http.Error(w, message, http.StatusForbidden)
|
||||
}
|
||||
73
internal/middleware/virtual_host.go
Normal file
73
internal/middleware/virtual_host.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// VirtualHostConfig controls host-style bucket resolution.
|
||||
type VirtualHostConfig struct {
|
||||
Enabled bool
|
||||
BucketExists func(string) bool
|
||||
}
|
||||
|
||||
// VirtualHost rewrites host-style requests (bucket.example.com) into path-style paths.
|
||||
func VirtualHost(cfg VirtualHostConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !cfg.Enabled {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
bucket := bucketFromHost(r.Host)
|
||||
if bucket == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.BucketExists != nil && !cfg.BucketExists(bucket) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/"+bucket) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
newPath := "/" + bucket
|
||||
if r.URL.Path != "/" {
|
||||
newPath += r.URL.Path
|
||||
}
|
||||
|
||||
clone := r.Clone(r.Context())
|
||||
clone.URL.Path = newPath
|
||||
clone.RequestURI = newPath
|
||||
|
||||
next.ServeHTTP(w, clone)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bucketFromHost(host string) string {
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
hostname := host
|
||||
if strings.Contains(host, ":") {
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err == nil {
|
||||
hostname = h
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return parts[0]
|
||||
}
|
||||
Reference in New Issue
Block a user