S3 signature V4 verify

This commit is contained in:
DullJZ
2025-10-01 20:36:53 +08:00
parent 1183775101
commit 48bc2d9b11
4 changed files with 36 additions and 41 deletions

View File

@@ -1,20 +1,41 @@
package middleware
import (
"encoding/base64"
"context"
"fmt"
"net/http"
"strings"
"github.com/DullJZ/s3-validate/pkg/s3validate"
)
// AuthConfig controls Basic Auth validation.
type AuthConfig struct {
// S3SignatureConfig controls S3 signature validation.
type S3SignatureConfig struct {
Required func() bool
Credentials func() (string, string)
OnError func(http.ResponseWriter, string, string, string)
}
// BasicAuth enforces static access/secret key authentication when required.
func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
// credentialsProvider implements s3validate.CredentialsProvider interface.
type credentialsProvider struct {
getCredentials func() (string, string)
}
func (p *credentialsProvider) SecretKey(ctx context.Context, accessKey string) (string, error) {
expectedAccessKey, secretKey := p.getCredentials()
if accessKey != expectedAccessKey {
return "", fmt.Errorf("invalid access key")
}
return secretKey, nil
}
// S3Signature enforces AWS Signature V4 authentication when required.
func S3Signature(cfg S3SignatureConfig) func(http.Handler) http.Handler {
verifier := &s3validate.Verifier{
Credentials: &credentialsProvider{
getCredentials: cfg.Credentials,
},
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
required := false
@@ -26,50 +47,21 @@ func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
return
}
authHeader := r.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, "Basic ") {
requireAuth(w, cfg)
return
}
payload := strings.TrimPrefix(authHeader, "Basic ")
decoded, err := base64.StdEncoding.DecodeString(payload)
result, err := verifier.Verify(r.Context(), r)
if err != nil {
requireAuth(w, cfg)
invokeOnError(w, cfg, "SignatureDoesNotMatch", err.Error())
return
}
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) != 2 {
requireAuth(w, cfg)
return
}
accessKey, secretKey := "", ""
if cfg.Credentials != nil {
accessKey, secretKey = cfg.Credentials()
}
if parts[0] != accessKey {
invokeOnError(w, cfg, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.")
return
}
if parts[1] != secretKey {
invokeOnError(w, cfg, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.")
return
}
// Store verification result in context for potential use by handlers
_ = result
next.ServeHTTP(w, r)
})
}
}
func requireAuth(w http.ResponseWriter, cfg AuthConfig) {
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
invokeOnError(w, cfg, "AccessDenied", "Access Denied")
}
func invokeOnError(w http.ResponseWriter, cfg AuthConfig, code, message string) {
func invokeOnError(w http.ResponseWriter, cfg S3SignatureConfig, code, message string) {
if cfg.OnError != nil {
cfg.OnError(w, code, message, "")
return