Files
s3-balance/internal/api/access_log.go
2025-10-05 01:58:54 +08:00

217 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"bufio"
"context"
"errors"
"log"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
)
// accessLogContextKey is a dedicated string type for the context keys declared
// in this file, so they cannot collide with plain string keys used elsewhere.
type accessLogContextKey string
const (
// errorCodeKey is the context key under which handleAuthError stores an error
// code and from which accessLogMiddleware reads it back.
errorCodeKey accessLogContextKey = "errorCode"
)
// accessLogMiddleware wraps next so that every handled request is reported to
// recordAccessLog with its action, bucket, key, size, outcome and duration.
//
// When h.storage is nil the request is passed straight through and nothing is
// recorded. A response counts as successful when its status code is below 400.
// For failures the error message is chosen in this order: the code stored in
// the request context under errorCodeKey, the X-Amz-Error-Code response header,
// and finally the standard text for the HTTP status.
func (h *S3Handler) accessLogMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if h.storage == nil {
			next.ServeHTTP(w, r)
			return
		}

		startedAt := time.Now()
		recorder := newLoggingResponseWriter(w)
		next.ServeHTTP(recorder, r)

		routeVars := mux.Vars(r)
		bucket, key := routeVars["bucket"], routeVars["key"]
		action := determineAccessAction(r, bucket, key)
		if action == "" && bucket == "" && key == "" {
			return
		}

		var errMsg string
		success := recorder.statusCode < 400
		if !success {
			// Prefer the error code placed in the context (set for auth errors).
			ctxCode, _ := r.Context().Value(errorCodeKey).(string)
			headerCode := recorder.Header().Get("X-Amz-Error-Code")
			switch {
			case ctxCode != "":
				errMsg = ctxCode
			case headerCode != "":
				errMsg = headerCode
			default:
				errMsg = http.StatusText(recorder.statusCode)
			}
		}

		size := calculateLogSize(r, recorder)
		elapsed := time.Since(startedAt)
		h.recordAccessLog(r, action, bucket, key, size, success, errMsg, elapsed)
	})
}
// recordAccessLog hands one access-log entry to h.storage.RecordAccessLog.
//
// The client IP, user agent and host are read from r before the goroutine is
// started; the storage call itself runs asynchronously so that logging does not
// block the request/response path. A failure to store the entry is only
// reported through the standard logger.
func (h *S3Handler) recordAccessLog(r *http.Request, action, bucket, key string, size int64, success bool, errMsg string, duration time.Duration) {
	ip, agent, hostName := extractClientIP(r), r.UserAgent(), r.Host
	elapsedMs := duration.Milliseconds()

	go func() {
		err := h.storage.RecordAccessLog(action, key, bucket, ip, agent, hostName, size, success, errMsg, elapsedMs)
		if err != nil {
			log.Printf("failed to record access log: %v", err)
		}
	}()
}
// handleAuthError stores code in the request context under errorCodeKey, for
// accessLogMiddleware to read, and then sends the S3 error response.
//
// The request struct that r points to is overwritten in place, so any code
// holding the same *http.Request pointer observes the updated context.
func (h *S3Handler) handleAuthError(w http.ResponseWriter, r *http.Request, code, message, resource string) {
	*r = *r.WithContext(context.WithValue(r.Context(), errorCodeKey, code))
	h.sendS3Error(w, code, message, resource)
}
// loggingResponseWriter wraps an http.ResponseWriter and records the status
// code and the number of body bytes written, for use by the access log.
type loggingResponseWriter struct {
	http.ResponseWriter
	statusCode   int   // final status of the response; http.StatusOK until a header is written
	bytesWritten int64 // body bytes accepted by the wrapped writer
	wroteHeader  bool  // true once the final response header has been committed
}

// newLoggingResponseWriter wraps w. The recorded status starts as
// http.StatusOK, which is what a handler that never calls WriteHeader sends.
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
	return &loggingResponseWriter{
		ResponseWriter: w,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader records statusCode and forwards the call to the wrapped writer.
//
// Only the first final status is recorded: once the header has been committed
// (explicitly, or implicitly by Write or Flush) later calls cannot change what
// the client received, so they must not change what is logged either.
// Informational 1xx statuses other than 101 are forwarded without being
// recorded because a final status still follows them.
func (lrw *loggingResponseWriter) WriteHeader(statusCode int) {
	interim := statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols
	if !interim && !lrw.wroteHeader {
		lrw.statusCode = statusCode
		lrw.wroteHeader = true
	}
	lrw.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards p to the wrapped writer and adds the number of bytes it
// accepted to bytesWritten. It returns the wrapped writer's result unchanged.
func (lrw *loggingResponseWriter) Write(p []byte) (int, error) {
	// Writing the body commits the header; if WriteHeader was never called the
	// recorded status stays at the default http.StatusOK.
	lrw.wroteHeader = true
	n, err := lrw.ResponseWriter.Write(p)
	lrw.bytesWritten += int64(n)
	return n, err
}

// Flush forwards to the wrapped writer when it implements http.Flusher and is
// a no-op otherwise.
func (lrw *loggingResponseWriter) Flush() {
	flusher, ok := lrw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// NOTE(review): the net/http server sends the header (implicit 200) on
	// Flush, so the header is treated as committed from here on.
	lrw.wroteHeader = true
	flusher.Flush()
}
// Hijack delegates to the wrapped writer when it implements http.Hijacker and
// otherwise returns an error stating that hijacking is not supported.
func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, supported := lrw.ResponseWriter.(http.Hijacker)
	if !supported {
		return nil, nil, errors.New("http.Hijacker not supported")
	}
	return hj.Hijack()
}
// Push delegates to the wrapped writer when it implements http.Pusher and
// otherwise returns http.ErrNotSupported.
func (lrw *loggingResponseWriter) Push(target string, opts *http.PushOptions) error {
	p, supported := lrw.ResponseWriter.(http.Pusher)
	if !supported {
		return http.ErrNotSupported
	}
	return p.Push(target, opts)
}
// determineAccessAction maps a request to the action name stored in the access
// log, based on the HTTP method, whether bucket and key are set, and which
// multipart-related query parameters (uploads, uploadId, partNumber) are
// present. Combinations without a dedicated name fall back to the lower-cased
// HTTP method.
func determineAccessAction(r *http.Request, bucket, key string) string {
	params := r.URL.Query()
	// has reports whether the query string carries name, with or without a value.
	has := func(name string) bool {
		_, present := params[name]
		return present
	}
	fallback := strings.ToLower(r.Method)

	switch {
	case bucket == "" && key == "":
		// Neither bucket nor key: service-level request.
		if r.Method == http.MethodGet {
			return "list_buckets"
		}
		return fallback

	case key == "":
		// Bucket without key: bucket-level request.
		switch r.Method {
		case http.MethodGet:
			if has("uploads") {
				return "list_multipart_uploads"
			}
			return "list_objects"
		case http.MethodHead:
			return "head_bucket"
		case http.MethodPut:
			return "create_bucket"
		case http.MethodDelete:
			return "delete_bucket"
		}
		return fallback
	}

	// A key is present: object-level request.
	switch r.Method {
	case http.MethodGet:
		switch {
		case has("uploads"):
			return "list_multipart_uploads"
		case has("uploadId"):
			return "list_multipart_parts"
		}
		return "download_object"
	case http.MethodHead:
		return "head_object"
	case http.MethodPut:
		if has("uploadId") && has("partNumber") {
			return "upload_part"
		}
		return "upload_object"
	case http.MethodDelete:
		if has("uploadId") {
			return "abort_multipart_upload"
		}
		return "delete_object"
	case http.MethodPost:
		switch {
		case has("uploads"):
			return "initiate_multipart_upload"
		case has("uploadId"):
			return "complete_multipart_upload"
		}
	}
	return fallback
}
// calculateLogSize returns the byte count to store in the access log.
//
// For PUT and POST requests it is the request's ContentLength when that is
// positive and 0 otherwise (for chunked transfers ContentLength is -1). For
// every other method it is the number of response body bytes written through
// lrw.
func calculateLogSize(r *http.Request, lrw *loggingResponseWriter) int64 {
	isUpload := r.Method == http.MethodPut || r.Method == http.MethodPost
	if !isUpload {
		return lrw.bytesWritten
	}
	if r.ContentLength <= 0 {
		return 0
	}
	return r.ContentLength
}
// extractClientIP returns the best available client address for r.
//
// Sources are tried in order: the first non-empty entry of X-Forwarded-For,
// then X-Real-IP, then the host part of r.RemoteAddr (or RemoteAddr unchanged
// when it cannot be split into host and port). Blank or whitespace-only header
// values are skipped so that a malformed header falls through to the next
// source instead of yielding an empty address.
//
// NOTE(review): both headers are supplied by the client unless a trusted proxy
// overwrites them, so the result is suitable for logging, not access control.
func extractClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Entries are comma-separated; use the leftmost entry that is not blank.
		for _, hop := range strings.Split(xff, ",") {
			if ip := strings.TrimSpace(hop); ip != "" {
				return ip
			}
		}
	}
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr has no port (or is otherwise unparsable); log it as is.
		return r.RemoteAddr
	}
	return host
}