mirror of
https://github.com/DullJZ/s3-balance.git
synced 2026-07-02 16:41:22 +08:00
S3 signature V4 verify
This commit is contained in:
@@ -1,20 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/DullJZ/s3-validate/pkg/s3validate"
|
||||
)
|
||||
|
||||
// AuthConfig controls Basic Auth validation.
|
||||
type AuthConfig struct {
|
||||
// S3SignatureConfig controls S3 signature validation.
|
||||
type S3SignatureConfig struct {
|
||||
Required func() bool
|
||||
Credentials func() (string, string)
|
||||
OnError func(http.ResponseWriter, string, string, string)
|
||||
}
|
||||
|
||||
// BasicAuth enforces static access/secret key authentication when required.
|
||||
func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
// credentialsProvider implements s3validate.CredentialsProvider interface.
|
||||
type credentialsProvider struct {
|
||||
getCredentials func() (string, string)
|
||||
}
|
||||
|
||||
func (p *credentialsProvider) SecretKey(ctx context.Context, accessKey string) (string, error) {
|
||||
expectedAccessKey, secretKey := p.getCredentials()
|
||||
if accessKey != expectedAccessKey {
|
||||
return "", fmt.Errorf("invalid access key")
|
||||
}
|
||||
return secretKey, nil
|
||||
}
|
||||
|
||||
// S3Signature enforces AWS Signature V4 authentication when required.
|
||||
func S3Signature(cfg S3SignatureConfig) func(http.Handler) http.Handler {
|
||||
verifier := &s3validate.Verifier{
|
||||
Credentials: &credentialsProvider{
|
||||
getCredentials: cfg.Credentials,
|
||||
},
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
required := false
|
||||
@@ -26,50 +47,21 @@ func BasicAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
requireAuth(w, cfg)
|
||||
return
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(authHeader, "Basic ")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
result, err := verifier.Verify(r.Context(), r)
|
||||
if err != nil {
|
||||
requireAuth(w, cfg)
|
||||
invokeOnError(w, cfg, "SignatureDoesNotMatch", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
requireAuth(w, cfg)
|
||||
return
|
||||
}
|
||||
|
||||
accessKey, secretKey := "", ""
|
||||
if cfg.Credentials != nil {
|
||||
accessKey, secretKey = cfg.Credentials()
|
||||
}
|
||||
|
||||
if parts[0] != accessKey {
|
||||
invokeOnError(w, cfg, "InvalidAccessKeyId", "The AWS Access Key Id you provided does not match the configured key.")
|
||||
return
|
||||
}
|
||||
if parts[1] != secretKey {
|
||||
invokeOnError(w, cfg, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.")
|
||||
return
|
||||
}
|
||||
// Store verification result in context for potential use by handlers
|
||||
_ = result
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func requireAuth(w http.ResponseWriter, cfg AuthConfig) {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"s3-balance\"")
|
||||
invokeOnError(w, cfg, "AccessDenied", "Access Denied")
|
||||
}
|
||||
|
||||
func invokeOnError(w http.ResponseWriter, cfg AuthConfig, code, message string) {
|
||||
func invokeOnError(w http.ResponseWriter, cfg S3SignatureConfig, code, message string) {
|
||||
if cfg.OnError != nil {
|
||||
cfg.OnError(w, code, message, "")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user