mirror of
https://github.com/DullJZ/s3-balance.git
synced 2026-06-27 14:01:23 +08:00
217 lines
5.3 KiB
Go
217 lines
5.3 KiB
Go
package api
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"errors"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gorilla/mux"
|
||
)
|
||
|
||
type accessLogContextKey string
|
||
|
||
const (
|
||
errorCodeKey accessLogContextKey = "errorCode"
|
||
)
|
||
|
||
func (h *S3Handler) accessLogMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if h.storage == nil {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
start := time.Now()
|
||
lrw := newLoggingResponseWriter(w)
|
||
next.ServeHTTP(lrw, r)
|
||
|
||
vars := mux.Vars(r)
|
||
bucket := vars["bucket"]
|
||
key := vars["key"]
|
||
action := determineAccessAction(r, bucket, key)
|
||
if action == "" && bucket == "" && key == "" {
|
||
return
|
||
}
|
||
|
||
success := lrw.statusCode < 400
|
||
errMsg := ""
|
||
if !success {
|
||
// 优先使用context中的错误码(auth错误设置的)
|
||
if code, ok := r.Context().Value(errorCodeKey).(string); ok && code != "" {
|
||
errMsg = code
|
||
} else if code := lrw.Header().Get("X-Amz-Error-Code"); code != "" {
|
||
errMsg = code
|
||
} else {
|
||
errMsg = http.StatusText(lrw.statusCode)
|
||
}
|
||
}
|
||
|
||
size := calculateLogSize(r, lrw)
|
||
duration := time.Since(start)
|
||
h.recordAccessLog(r, action, bucket, key, size, success, errMsg, duration)
|
||
})
|
||
}
|
||
|
||
func (h *S3Handler) recordAccessLog(r *http.Request, action, bucket, key string, size int64, success bool, errMsg string, duration time.Duration) {
|
||
clientIP := extractClientIP(r)
|
||
userAgent := r.UserAgent()
|
||
host := r.Host
|
||
// 异步记录日志,避免阻塞请求响应
|
||
go func() {
|
||
if err := h.storage.RecordAccessLog(action, key, bucket, clientIP, userAgent, host, size, success, errMsg, duration.Milliseconds()); err != nil {
|
||
log.Printf("failed to record access log: %v", err)
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (h *S3Handler) handleAuthError(w http.ResponseWriter, r *http.Request, code, message, resource string) {
|
||
// 将错误码存入context,供accessLogMiddleware使用
|
||
ctx := context.WithValue(r.Context(), errorCodeKey, code)
|
||
*r = *r.WithContext(ctx)
|
||
|
||
h.sendS3Error(w, code, message, resource)
|
||
}
|
||
|
||
type loggingResponseWriter struct {
|
||
http.ResponseWriter
|
||
statusCode int
|
||
bytesWritten int64
|
||
}
|
||
|
||
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
|
||
return &loggingResponseWriter{
|
||
ResponseWriter: w,
|
||
statusCode: http.StatusOK,
|
||
}
|
||
}
|
||
|
||
func (lrw *loggingResponseWriter) WriteHeader(statusCode int) {
|
||
lrw.statusCode = statusCode
|
||
lrw.ResponseWriter.WriteHeader(statusCode)
|
||
}
|
||
|
||
func (lrw *loggingResponseWriter) Write(p []byte) (int, error) {
|
||
n, err := lrw.ResponseWriter.Write(p)
|
||
lrw.bytesWritten += int64(n)
|
||
return n, err
|
||
}
|
||
|
||
func (lrw *loggingResponseWriter) Flush() {
|
||
if flusher, ok := lrw.ResponseWriter.(http.Flusher); ok {
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
|
||
func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||
if hijacker, ok := lrw.ResponseWriter.(http.Hijacker); ok {
|
||
return hijacker.Hijack()
|
||
}
|
||
return nil, nil, errors.New("http.Hijacker not supported")
|
||
}
|
||
|
||
func (lrw *loggingResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||
if pusher, ok := lrw.ResponseWriter.(http.Pusher); ok {
|
||
return pusher.Push(target, opts)
|
||
}
|
||
return http.ErrNotSupported
|
||
}
|
||
|
||
func determineAccessAction(r *http.Request, bucket, key string) string {
|
||
method := r.Method
|
||
query := r.URL.Query()
|
||
|
||
if bucket == "" && key == "" {
|
||
if method == http.MethodGet {
|
||
return "list_buckets"
|
||
}
|
||
return strings.ToLower(method)
|
||
}
|
||
|
||
if key == "" {
|
||
switch method {
|
||
case http.MethodGet:
|
||
if _, ok := query["uploads"]; ok {
|
||
return "list_multipart_uploads"
|
||
}
|
||
return "list_objects"
|
||
case http.MethodHead:
|
||
return "head_bucket"
|
||
case http.MethodPut:
|
||
return "create_bucket"
|
||
case http.MethodDelete:
|
||
return "delete_bucket"
|
||
}
|
||
return strings.ToLower(method)
|
||
}
|
||
|
||
switch method {
|
||
case http.MethodGet:
|
||
if _, ok := query["uploads"]; ok {
|
||
return "list_multipart_uploads"
|
||
}
|
||
if _, ok := query["uploadId"]; ok {
|
||
return "list_multipart_parts"
|
||
}
|
||
return "download_object"
|
||
case http.MethodHead:
|
||
return "head_object"
|
||
case http.MethodPut:
|
||
if _, hasUploadID := query["uploadId"]; hasUploadID {
|
||
if _, hasPart := query["partNumber"]; hasPart {
|
||
return "upload_part"
|
||
}
|
||
}
|
||
return "upload_object"
|
||
case http.MethodDelete:
|
||
if _, ok := query["uploadId"]; ok {
|
||
return "abort_multipart_upload"
|
||
}
|
||
return "delete_object"
|
||
case http.MethodPost:
|
||
if _, ok := query["uploads"]; ok {
|
||
return "initiate_multipart_upload"
|
||
}
|
||
if _, ok := query["uploadId"]; ok {
|
||
return "complete_multipart_upload"
|
||
}
|
||
}
|
||
|
||
return strings.ToLower(method)
|
||
}
|
||
|
||
func calculateLogSize(r *http.Request, lrw *loggingResponseWriter) int64 {
|
||
switch r.Method {
|
||
case http.MethodPut, http.MethodPost:
|
||
// 对于上传请求,优先使用请求体大小(如果可用)
|
||
if r.ContentLength > 0 {
|
||
return r.ContentLength
|
||
}
|
||
// 对于分块传输(chunked),ContentLength 为 -1,返回 0
|
||
return 0
|
||
}
|
||
// 对于 GET/HEAD 等请求,返回响应体大小
|
||
return lrw.bytesWritten
|
||
}
|
||
|
||
func extractClientIP(r *http.Request) string {
|
||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||
parts := strings.Split(xff, ",")
|
||
if len(parts) > 0 {
|
||
return strings.TrimSpace(parts[0])
|
||
}
|
||
}
|
||
if xrip := r.Header.Get("X-Real-IP"); xrip != "" {
|
||
return strings.TrimSpace(xrip)
|
||
}
|
||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||
if err != nil {
|
||
return r.RemoteAddr
|
||
}
|
||
return host
|
||
}
|