Merge pull request #6 from DullJZ/add-api

Add api
This commit is contained in:
DullJZ
2025-11-15 12:32:09 +08:00
committed by GitHub
25 changed files with 1764 additions and 65 deletions

View File

@@ -19,6 +19,21 @@ jobs:
run: |
echo "VERSION=$(cat VERSION | tr -d '\n')" >> $GITHUB_ENV
- name: Download frontend assets
run: |
echo "Downloading latest frontend build from s3-balance-web..."
LATEST_RELEASE=$(curl -s https://api.github.com/repos/DullJZ/s3-balance-web/releases/latest | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
echo "Latest frontend release: $LATEST_RELEASE"
curl -L -o dist.tar.gz "https://github.com/DullJZ/s3-balance-web/releases/download/$LATEST_RELEASE/dist.tar.gz"
mkdir -p internal/webui/dist
tar -xzf dist.tar.gz -C internal/webui/
rm dist.tar.gz
echo "Frontend assets downloaded and extracted to internal/webui/dist"
ls -la internal/webui/dist/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
with:
@@ -62,6 +77,21 @@ jobs:
run: |
echo "VERSION=$(cat VERSION | tr -d '\n')" >> $GITHUB_ENV
- name: Download frontend assets
run: |
echo "Downloading latest frontend build from s3-balance-web..."
LATEST_RELEASE=$(curl -s https://api.github.com/repos/DullJZ/s3-balance-web/releases/latest | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
echo "Latest frontend release: $LATEST_RELEASE"
curl -L -o dist.tar.gz "https://github.com/DullJZ/s3-balance-web/releases/download/$LATEST_RELEASE/dist.tar.gz"
mkdir -p internal/webui/dist
tar -xzf dist.tar.gz -C internal/webui/
rm dist.tar.gz
echo "Frontend assets downloaded and extracted to internal/webui/dist"
ls -la internal/webui/dist/
- name: Setup Go
uses: actions/setup-go@v2
with:

10
.gitignore vendored
View File

@@ -46,3 +46,13 @@ logs/
# Environment variables
.env
.env.local
# AI doc
AGENTS.md
CLAUDE.md
docs/
# Generated files
s3-balance
dist/

View File

@@ -17,7 +17,11 @@ import (
"github.com/DullJZ/s3-balance/internal/config"
"github.com/DullJZ/s3-balance/internal/database"
"github.com/DullJZ/s3-balance/internal/metrics"
"github.com/DullJZ/s3-balance/internal/middleware"
"github.com/DullJZ/s3-balance/internal/scheduler"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/DullJZ/s3-balance/internal/web"
"github.com/DullJZ/s3-balance/internal/webui"
"github.com/DullJZ/s3-balance/pkg/presigner"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -26,9 +30,17 @@ import (
func main() {
// 解析命令行参数
var configFile string
var onlyWeb bool
flag.StringVar(&configFile, "config", "config/config.yaml", "Path to configuration file")
flag.BoolVar(&onlyWeb, "only-web", false, "Only serve web UI, no backend services")
flag.Parse()
// 如果是只提供Web前端模式
if onlyWeb {
startWebOnlyMode(configFile)
return
}
// 创建配置管理器
configManager, err := config.NewManager(configFile)
if err != nil {
@@ -45,11 +57,14 @@ func main() {
}
defer database.Close()
// 创建存储服务
storageService := storage.NewService(database.GetDB())
// 创建指标服务
metricsService := metrics.New()
// 创建存储桶管理器
bucketManager, err := bucket.NewManager(cfg, metricsService)
bucketManager, err := bucket.NewManager(cfg, metricsService, storageService)
if err != nil {
log.Fatalf("Failed to create bucket manager: %v", err)
}
@@ -74,12 +89,14 @@ func main() {
60*time.Minute, // 下载URL有效期
)
// 创建存储服务
storageService := storage.NewService(database.GetDB())
// 启动定期清理过期上传会话的任务
startSessionCleaner(ctx, storageService)
// 启动月度统计归档任务(每小时检查一次)
monthlyArchiver := scheduler.NewMonthlyArchiver(storageService, 1*time.Hour)
monthlyArchiver.Start()
defer monthlyArchiver.Stop()
// 创建S3兼容API处理器
s3Handler := api.NewS3Handler(
bucketManager,
@@ -124,6 +141,32 @@ func main() {
log.Printf("Metrics server enabled at %s", cfg.Metrics.Path)
}
// 注册管理API路由如果启用
// 必须在S3路由之前注册因为S3路由使用 /{bucket} 通配符会匹配所有路径
if cfg.API.Enabled {
log.Println("Management API enabled")
adminHandler := api.NewAdminHandler(bucketManager, lb, cfg, configManager)
statsHandler := api.NewStatsHandler(storageService)
// 创建子路由器并应用中间件
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(corsMiddleware) // 先应用 CORS 中间件,处理 OPTIONS 预检请求
apiRouter.Use(middleware.TokenAuthMiddleware(cfg.API.Token))
adminHandler.RegisterRoutes(apiRouter)
statsHandler.RegisterRoutes(apiRouter)
log.Printf("Management API endpoints available at /api/*")
}
// 注册Web管理界面
distSubFS, err := webui.GetDistFS()
if err != nil {
log.Fatalf("Failed to load embedded web UI: %v", err)
}
webHandler := web.NewHandler(distSubFS)
router.PathPrefix("/web").Handler(http.StripPrefix("/web", webHandler))
log.Println("Web UI available at /web")
// 运行在S3兼容模式
log.Println("Running in S3-compatible mode")
s3Handler.RegisterS3Routes(router)
@@ -274,3 +317,77 @@ func cleanupS3MultipartUploads(_ context.Context, storageService *storage.Servic
}
}
}
// startWebOnlyMode 只启动Web前端服务不启动后端服务
func startWebOnlyMode(configFile string) {
log.Println("Starting in web-only mode (no backend services)")
// 加载配置文件以获取端口等信息
configManager, err := config.NewManager(configFile)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
defer configManager.Close()
cfg := configManager.GetConfig()
// 创建路由器
router := mux.NewRouter()
// 加载嵌入的前端资源
distSubFS, err := webui.GetDistFS()
if err != nil {
log.Fatalf("Failed to load embedded web UI: %v", err)
}
// 注册Web前端路由
webHandler := web.NewHandler(distSubFS)
// 根路径重定向到 /web
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/web/", http.StatusMovedPermanently)
})
// Web UI 路由
router.PathPrefix("/web").Handler(http.StripPrefix("/web", webHandler))
// 添加 CORS 和日志中间件
router.Use(corsMiddleware)
router.Use(loggingMiddleware)
// 使用配置文件中的端口
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
srv := &http.Server{
Addr: addr,
Handler: router,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
}
log.Println("Web UI available at /web")
log.Printf("Starting web server on %s", srv.Addr)
// 启动服务器
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start server: %v", err)
}
}()
// 等待中断信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
// 优雅关闭
log.Println("Shutting down web server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Server shutdown error: %v", err)
}
log.Println("Web server stopped")
}

View File

@@ -48,6 +48,9 @@ buckets:
enabled: true
path_style: false # AWS S3使用虚拟主机风格
virtual: false # 这是真实存储桶
operation_limits:
type_a: 0 # 类型A操作写入类上限0表示不限制
type_b: 0 # 类型B操作读取类上限0表示不限制
# 真实存储桶 - MinIO用于存储数据对客户端隐藏
- name: "my-bucket-2"
@@ -60,6 +63,9 @@ buckets:
enabled: true
path_style: true # MinIO通常使用路径风格
virtual: false # 这是真实存储桶
operation_limits:
type_a: 0
type_b: 0
# 虚拟存储桶 - user-bucket-1对客户端可见的唯一存储桶
- name: "user-bucket-1"
@@ -96,6 +102,9 @@ buckets:
enabled: true
path_style: false
virtual: false # 这是真实存储桶
operation_limits:
type_a: 0
type_b: 0
# 负载均衡配置
balancer:
@@ -124,19 +133,20 @@ metrics:
s3api:
# 客户端连接用的Access Key
access_key: "AKIAIOSFODNN7EXAMPLE"
# 客户端连接用的Secret Key
secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
# 是否使用虚拟主机模式
virtual_host: false
# 工作模式:
# false预签名重定向模式客户端直接与后端存储交互
# true (默认)代理模式数据通过S3 Balance服务器传输
# false预签名重定向模式客户端下载直接重定向到与后端存储
# true (默认)代理模式数据通过S3 Balance服务器中转传输
# 该选项仅适用于下载,上传操作始终为全代理模式
proxy_mode: true
# 是否需要认证(开启后使用 Basic Auth凭据来自 access_key/secret_key
# 是否需要认证(使用配置的 access_key/secret_key
auth_required: true
# 用于签名验证的Host可选
@@ -144,3 +154,12 @@ s3api:
# 留空则使用请求中的 Host 头
# 示例: "s3.example.com" 或 "s3.example.com:8080"
host: ""
# 管理API配置
api:
# 是否启用管理API
enabled: true
# API访问令牌用于管理接口的身份验证
# 请修改为强密码,建议使用随机生成的长字符串
token: "your-secure-api-token-change-this"

View File

@@ -99,6 +99,67 @@
"title": "QPS (操作/秒)",
"type": "timeseries"
},
{
"datasource": "Prometheus",
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"custom": {
"align": "auto",
"displayMode": "auto"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 0
},
"id": 12,
"options": {
"footer": {
"fields": "",
"reducer": [
"sum"
],
"show": false
},
"showHeader": true
},
"targets": [
{
"expr": "sum by (bucket, category)(increase(s3_balance_backend_operations_total[$__range]))",
"instant": true,
"legendFormat": "",
"refId": "A"
}
],
"timeFrom": "now/M",
"title": "后端操作次数 (本自然月)",
"transformations": [
{
"id": "labelsToFields",
"options": {
"valueLabel": "操作次数"
}
}
],
"type": "table"
},
{
"datasource": "Prometheus",
"fieldConfig": {
@@ -334,5 +395,5 @@
"timezone": "",
"title": "S3 Balance 监控面板",
"uid": "s3-balance-monitoring",
"version": 1
"version": 2
}

View File

@@ -0,0 +1,234 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/DullJZ/s3-balance/internal/balancer"
"github.com/DullJZ/s3-balance/internal/bucket"
"github.com/DullJZ/s3-balance/internal/config"
"github.com/gorilla/mux"
)
// AdminHandler 管理API处理器
type AdminHandler struct {
bucketManager *bucket.Manager
balancer *balancer.Balancer
config *config.Config
configManager *config.Manager
}
// NewAdminHandler 创建新的管理API处理器
func NewAdminHandler(
bucketManager *bucket.Manager,
balancer *balancer.Balancer,
cfg *config.Config,
configManager *config.Manager,
) *AdminHandler {
return &AdminHandler{
bucketManager: bucketManager,
balancer: balancer,
config: cfg,
configManager: configManager,
}
}
// BucketResponse 存储桶响应结构
type BucketResponse struct {
Name string `json:"name"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
MaxSize string `json:"max_size"`
MaxSizeBytes int64 `json:"max_size_bytes"`
UsedSize int64 `json:"used_size"`
AvailableSize int64 `json:"available_size"`
UsagePercent float64 `json:"usage_percent"`
Weight int `json:"weight"`
Enabled bool `json:"enabled"`
Available bool `json:"available"`
Virtual bool `json:"virtual"`
LastChecked time.Time `json:"last_checked"`
OperationCountA int64 `json:"operation_count_a"`
OperationCountB int64 `json:"operation_count_b"`
OperationLimits struct {
TypeA int `json:"type_a"`
TypeB int `json:"type_b"`
} `json:"operation_limits"`
}
// BucketsListResponse 存储桶列表响应结构
type BucketsListResponse struct {
Total int `json:"total"`
Buckets []BucketResponse `json:"buckets"`
}
// HealthResponse 健康状态响应结构
type HealthResponse struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
LoadBalancer string `json:"load_balancer_strategy"`
TotalBuckets int `json:"total_buckets"`
AvailableBuckets int `json:"available_buckets"`
Database string `json:"database_type"`
}
// RegisterRoutes 注册管理API路由
// 注意: router 参数应该是已经带有 /api 前缀的子路由器
func (h *AdminHandler) RegisterRoutes(router *mux.Router) {
// 注册路由,同时支持 OPTIONS 方法用于 CORS 预检
router.HandleFunc("/buckets", h.ListBuckets).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/buckets/{name}", h.GetBucketDetail).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/health", h.GetHealth).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/config", h.GetConfig).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/config", h.UpdateConfig).Methods(http.MethodPost, http.MethodOptions)
}
// ListBuckets 获取存储桶列表
func (h *AdminHandler) ListBuckets(w http.ResponseWriter, r *http.Request) {
buckets := h.bucketManager.GetAllBuckets()
response := BucketsListResponse{
Total: len(buckets),
Buckets: make([]BucketResponse, 0, len(buckets)),
}
for _, b := range buckets {
bucketResp := h.convertBucketInfo(b)
response.Buckets = append(response.Buckets, bucketResp)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetBucketDetail 获取存储桶详情
func (h *AdminHandler) GetBucketDetail(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
name := vars["name"]
bucketInfo, exists := h.bucketManager.GetBucket(name)
if !exists {
http.Error(w, `{"error": "bucket not found"}`, http.StatusNotFound)
return
}
response := h.convertBucketInfo(bucketInfo)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetHealth 获取系统健康状态
func (h *AdminHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
buckets := h.bucketManager.GetAllBuckets()
availableBuckets := h.bucketManager.GetAvailableBuckets()
status := "healthy"
if len(availableBuckets) == 0 {
status = "unhealthy"
} else if len(availableBuckets) < len(buckets)/2 {
status = "degraded"
}
response := HealthResponse{
Status: status,
Timestamp: time.Now(),
LoadBalancer: h.config.Balancer.Strategy,
TotalBuckets: len(buckets),
AvailableBuckets: len(availableBuckets),
Database: h.config.Database.Type,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetConfig 获取当前配置
func (h *AdminHandler) GetConfig(w http.ResponseWriter, r *http.Request) {
if h.configManager == nil {
http.Error(w, `{"error": "config manager not available"}`, http.StatusInternalServerError)
return
}
currentConfig := h.configManager.GetConfig()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(currentConfig)
}
// UpdateConfig 更新配置
func (h *AdminHandler) UpdateConfig(w http.ResponseWriter, r *http.Request) {
if h.configManager == nil {
http.Error(w, `{"error": "config manager not available"}`, http.StatusInternalServerError)
return
}
// 解析请求体
var newConfig config.Config
if err := json.NewDecoder(r.Body).Decode(&newConfig); err != nil {
http.Error(w, `{"error": "invalid JSON format: `+err.Error()+`"}`, http.StatusBadRequest)
return
}
// 设置默认值
newConfig.SetDefaults()
// 更新配置
if err := h.configManager.UpdateConfig(&newConfig); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "validation failed",
"message": err.Error(),
})
return
}
// 返回成功响应
response := map[string]interface{}{
"success": true,
"message": "Configuration updated successfully. Changes will take effect automatically.",
"config": &newConfig,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// convertBucketInfo 转换BucketInfo为BucketResponse
func (h *AdminHandler) convertBucketInfo(b *bucket.BucketInfo) BucketResponse {
resp := BucketResponse{
Name: b.Config.Name,
Endpoint: b.Config.Endpoint,
Region: b.Config.Region,
MaxSize: b.Config.MaxSize,
MaxSizeBytes: b.Config.MaxSizeBytes,
UsedSize: b.UsedSize,
Weight: b.Config.Weight,
Enabled: b.Config.Enabled,
Available: b.Available,
Virtual: b.Config.Virtual,
LastChecked: b.LastChecked,
OperationCountA: b.GetOperationCount(bucket.OperationTypeA),
OperationCountB: b.GetOperationCount(bucket.OperationTypeB),
}
resp.OperationLimits.TypeA = b.Config.OperationLimits.TypeA
resp.OperationLimits.TypeB = b.Config.OperationLimits.TypeB
// 计算可用空间
if b.Config.MaxSizeBytes > 0 {
resp.AvailableSize = b.Config.MaxSizeBytes - b.UsedSize
if resp.AvailableSize < 0 {
resp.AvailableSize = 0
}
// 计算使用百分比
resp.UsagePercent = float64(b.UsedSize) / float64(b.Config.MaxSizeBytes) * 100
} else {
resp.AvailableSize = -1 // -1 表示无限制
resp.UsagePercent = 0
}
return resp
}

View File

@@ -76,6 +76,8 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
return
}
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
// 检查当前已上传大小 + 本次分片大小是否超过bucket剩余空间
currentSize, err := h.storage.GetUploadSessionSize(uploadID)
if err != nil {
@@ -213,6 +215,8 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
return
}
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
// 初始化分片上传
ctx := context.Background()
createResp, err := targetBucket.Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
@@ -280,14 +284,16 @@ func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Re
// 降级到遍历所有存储桶的方式
ctx := context.Background()
allBuckets := h.bucketManager.GetAllBuckets()
for _, bucket := range allBuckets {
if bucket.IsVirtual() {
for _, realBucket := range allBuckets {
if realBucket.IsVirtual() {
continue
}
h.recordBackendOperation(realBucket, bucket.OperationTypeB)
// 列出每个真实存储桶的分片上传
listResp, err := bucket.Client.ListMultipartUploads(ctx, &s3.ListMultipartUploadsInput{
Bucket: aws.String(bucket.Config.Name),
listResp, err := realBucket.Client.ListMultipartUploads(ctx, &s3.ListMultipartUploadsInput{
Bucket: aws.String(realBucket.Config.Name),
KeyMarker: aws.String(keyMarker),
UploadIdMarker: aws.String(uploadIdMarker),
Prefix: aws.String(prefix),
@@ -295,7 +301,7 @@ func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Re
MaxUploads: aws.Int32(int32(maxUploads)),
})
if err != nil {
log.Printf("Failed to list multipart uploads for bucket %s: %v", bucket.Config.Name, err)
log.Printf("Failed to list multipart uploads for bucket %s: %v", realBucket.Config.Name, err)
continue
}
@@ -409,21 +415,22 @@ func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Requ
if err != nil {
// 如果没有找到映射,尝试查询所有真实存储桶
allBuckets := h.bucketManager.GetAllBuckets()
for _, bucket := range allBuckets {
if bucket.IsVirtual() {
for _, realBucket := range allBuckets {
if realBucket.IsVirtual() {
continue
}
// 尝试列出分片,如果成功则说明上传在这个桶中
ctx := context.Background()
_, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(bucket.Config.Name),
h.recordBackendOperation(realBucket, bucket.OperationTypeB)
_, err := realBucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(realBucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
MaxParts: aws.Int32(1), // 只检查是否存在
})
if err == nil {
targetBucket = bucket
targetBucket = realBucket
break
}
}
@@ -446,6 +453,7 @@ func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Requ
}
// 列出分片
h.recordBackendOperation(targetBucket, bucket.OperationTypeB)
ctx := context.Background()
listResp, err := targetBucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(targetBucket.Config.Name),
@@ -570,6 +578,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
// 完成分片上传
ctx := context.Background()
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
sort.SliceStable(completeReq.Parts, func(i, j int) bool {
return completeReq.Parts[i].PartNumber < completeReq.Parts[j].PartNumber
})
@@ -611,6 +620,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
// 获取完成上传后的对象大小
var objectSize int64
h.recordBackendOperation(targetBucket, bucket.OperationTypeB)
headResp, err := targetBucket.Client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),
@@ -658,6 +668,7 @@ func getAPIError(err error) (smithy.APIError, bool) {
// abortMultipartUploadInternal 内部方法向后端S3发送中止分片上传请求
func (h *S3Handler) abortMultipartUploadInternal(targetBucket *bucket.BucketInfo, key, uploadID string) error {
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
ctx := context.Background()
_, err := targetBucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(targetBucket.Config.Name),
@@ -718,6 +729,7 @@ func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Re
// 中止分片上传
ctx := context.Background()
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
_, err := targetBucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),

View File

@@ -70,6 +70,8 @@ func (h *S3Handler) handleGetObject(w http.ResponseWriter, r *http.Request, buck
h.sendS3Error(w, "InternalError", "Mapped real bucket not found", key)
return
}
h.recordBackendOperation(bucket1, bucket.OperationTypeB)
}
// 生成预签名下载URL
@@ -228,6 +230,8 @@ func (h *S3Handler) handlePutObject(w http.ResponseWriter, r *http.Request, buck
return
}
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
// 生成预签名上传URL
uploadInfo, err := h.presigner.GenerateUploadURL(
context.Background(),
@@ -295,7 +299,7 @@ func (h *S3Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request, b
return
}
var bucket *bucket.BucketInfo
var targetBucket *bucket.BucketInfo
var err error
if requestedBucket.IsVirtual() {
@@ -308,7 +312,7 @@ func (h *S3Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request, b
}
// 获取映射到的真实存储桶
bucket, ok = h.bucketManager.GetBucket(mapping.RealBucketName)
targetBucket, ok = h.bucketManager.GetBucket(mapping.RealBucketName)
if !ok {
h.sendS3Error(w, "InternalError", "Mapped real bucket not found", key)
return
@@ -319,10 +323,12 @@ func (h *S3Handler) handleDeleteObject(w http.ResponseWriter, r *http.Request, b
return
}
h.recordBackendOperation(targetBucket, bucket.OperationTypeA)
// 生成预签名删除URL
deleteInfo, err := h.presigner.GenerateDeleteURL(
context.Background(),
bucket,
targetBucket,
key,
)
if err != nil {

View File

@@ -0,0 +1,36 @@
package api
import (
"log"
"github.com/DullJZ/s3-balance/internal/bucket"
)
// recordBackendOperation increments backend operation counters and disables the bucket if limits are exceeded.
func (h *S3Handler) recordBackendOperation(b *bucket.BucketInfo, category bucket.OperationCategory) {
if b == nil {
return
}
if h.metrics != nil {
h.metrics.RecordBackendOperation(b.Config.Name, string(category))
}
var disabled bool
if h.storage != nil {
newCount, err := h.storage.IncrementBucketOperation(b.Config.Name, string(category))
if err != nil {
log.Printf("failed to persist backend operation count for bucket %s: %v", b.Config.Name, err)
disabled = b.RecordOperation(category)
} else {
disabled = b.SetOperationCount(category, newCount)
}
} else {
disabled = b.RecordOperation(category)
}
if disabled {
log.Printf("Bucket %s disabled after exceeding %s-type operation limit", b.Config.Name, category)
}
}

View File

@@ -0,0 +1,194 @@
package api
import (
"encoding/json"
"log"
"net/http"
"strconv"
"time"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/gorilla/mux"
)
// StatsHandler 统计数据处理器
type StatsHandler struct {
storage *storage.Service
}
// NewStatsHandler 创建统计处理器
func NewStatsHandler(storage *storage.Service) *StatsHandler {
return &StatsHandler{
storage: storage,
}
}
// RegisterRoutes 注册统计API路由
func (h *StatsHandler) RegisterRoutes(router *mux.Router) {
// 注意: router 参数应该是已经带有 /api 前缀的子路由器
router.HandleFunc("/stats/monthly", h.GetCurrentMonthStats).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/stats/monthly/{year}/{month}", h.GetMonthlyStats).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/stats/monthly/range", h.GetMonthlyStatsRange).Methods(http.MethodGet, http.MethodOptions)
router.HandleFunc("/stats/bucket/{bucket}/history", h.GetBucketHistory).Methods(http.MethodGet, http.MethodOptions)
}
// MonthlyStatsResponse 月度统计响应
type MonthlyStatsResponse struct {
Year int `json:"year"`
Month int `json:"month"`
Bucket string `json:"bucket"`
Stats BucketOperationCounts `json:"stats"`
}
// BucketOperationCounts 存储桶操作计数
type BucketOperationCounts struct {
OperationCountA int64 `json:"operation_count_a"`
OperationCountB int64 `json:"operation_count_b"`
Total int64 `json:"total"`
}
// GetCurrentMonthStats 获取当前月份的统计
func (h *StatsHandler) GetCurrentMonthStats(w http.ResponseWriter, r *http.Request) {
stats, err := h.storage.GetCurrentMonthStats()
if err != nil {
log.Printf("Failed to get current month stats: %v", err)
http.Error(w, "Failed to fetch statistics", http.StatusInternalServerError)
return
}
response := h.formatMonthlyStats(stats)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetMonthlyStats 获取指定月份的统计
func (h *StatsHandler) GetMonthlyStats(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
year, err := strconv.Atoi(vars["year"])
if err != nil {
http.Error(w, "Invalid year", http.StatusBadRequest)
return
}
month, err := strconv.Atoi(vars["month"])
if err != nil || month < 1 || month > 12 {
http.Error(w, "Invalid month", http.StatusBadRequest)
return
}
stats, err := h.storage.GetMonthlyStats(year, month)
if err != nil {
log.Printf("Failed to get monthly stats: %v", err)
http.Error(w, "Failed to fetch statistics", http.StatusInternalServerError)
return
}
response := h.formatMonthlyStats(stats)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetMonthlyStatsRange 获取时间范围内的统计
func (h *StatsHandler) GetMonthlyStatsRange(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
startYear, err := strconv.Atoi(query.Get("start_year"))
if err != nil {
http.Error(w, "Invalid start_year", http.StatusBadRequest)
return
}
startMonth, err := strconv.Atoi(query.Get("start_month"))
if err != nil || startMonth < 1 || startMonth > 12 {
http.Error(w, "Invalid start_month", http.StatusBadRequest)
return
}
endYear, err := strconv.Atoi(query.Get("end_year"))
if err != nil {
http.Error(w, "Invalid end_year", http.StatusBadRequest)
return
}
endMonth, err := strconv.Atoi(query.Get("end_month"))
if err != nil || endMonth < 1 || endMonth > 12 {
http.Error(w, "Invalid end_month", http.StatusBadRequest)
return
}
stats, err := h.storage.GetMonthlyStatsRange(startYear, startMonth, endYear, endMonth)
if err != nil {
log.Printf("Failed to get monthly stats range: %v", err)
http.Error(w, "Failed to fetch statistics", http.StatusInternalServerError)
return
}
response := h.formatMonthlyStats(stats)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// GetBucketHistory 获取指定存储桶的历史统计
func (h *StatsHandler) GetBucketHistory(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucket := vars["bucket"]
// 获取查询参数中的月份数默认12个月
months := 12
if monthsStr := r.URL.Query().Get("months"); monthsStr != "" {
if m, err := strconv.Atoi(monthsStr); err == nil && m > 0 {
months = m
}
}
stats, err := h.storage.GetBucketMonthlyHistory(bucket, months)
if err != nil {
log.Printf("Failed to get bucket history: %v", err)
http.Error(w, "Failed to fetch statistics", http.StatusInternalServerError)
return
}
response := h.formatMonthlyStats(stats)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// formatMonthlyStats 格式化月度统计数据
func (h *StatsHandler) formatMonthlyStats(stats []storage.BucketMonthlyStats) []MonthlyStatsResponse {
result := make([]MonthlyStatsResponse, 0, len(stats))
for _, stat := range stats {
result = append(result, MonthlyStatsResponse{
Year: stat.Year,
Month: stat.Month,
Bucket: stat.BucketName,
Stats: BucketOperationCounts{
OperationCountA: stat.OperationCountA,
OperationCountB: stat.OperationCountB,
Total: stat.OperationCountA + stat.OperationCountB,
},
})
}
return result
}
// ArchiveCurrentMonth 手动触发归档当前月份管理API
func (h *StatsHandler) ArchiveCurrentMonth(w http.ResponseWriter, r *http.Request) {
now := time.Now()
year, month := now.Year(), int(now.Month())
if err := h.storage.ArchiveMonthlyStats(year, month); err != nil {
log.Printf("Failed to archive monthly stats: %v", err)
http.Error(w, "Failed to archive statistics", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"status": "success",
"message": "Monthly statistics archived successfully",
"year": strconv.Itoa(year),
"month": strconv.Itoa(month),
})
}

View File

@@ -10,20 +10,34 @@ import (
"github.com/DullJZ/s3-balance/internal/config"
"github.com/DullJZ/s3-balance/internal/health"
"github.com/DullJZ/s3-balance/internal/metrics"
"github.com/DullJZ/s3-balance/internal/storage"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
// OperationCategory 表示后端操作分类
type OperationCategory string
const (
// OperationTypeA 表示写入类操作
OperationTypeA OperationCategory = "A"
// OperationTypeB 表示读取类操作
OperationTypeB OperationCategory = "B"
)
// BucketInfo 存储桶信息
type BucketInfo struct {
Config config.BucketConfig
Client *s3.Client
UsedSize int64 // 已使用容量(字节)
Available bool // 是否可用由health监控更新
LastChecked time.Time // 最后检查时间由health监控更新
mu sync.RWMutex
Config config.BucketConfig
Client *s3.Client
UsedSize int64 // 已使用容量(字节)
Available bool // 是否可用由health监控更新
LastChecked time.Time // 最后检查时间由health监控更新
mu sync.RWMutex
operationCountA int64
operationCountB int64
operationLimitReached bool
}
// Manager 存储桶管理器
@@ -36,15 +50,17 @@ type Manager struct {
healthMonitor *health.Monitor
statsMonitor *health.StatsMonitor
monitorCtx context.Context
storage *storage.Service
}
// NewManager 创建新的存储桶管理器
func NewManager(cfg *config.Config, metrics *metrics.Metrics) (*Manager, error) {
func NewManager(cfg *config.Config, metrics *metrics.Metrics, storageService *storage.Service) (*Manager, error) {
m := &Manager{
buckets: make(map[string]*BucketInfo),
config: cfg,
stopChan: make(chan struct{}),
metrics: metrics,
storage: storageService,
}
// 初始化所有存储桶客户端
@@ -71,9 +87,49 @@ func NewManager(cfg *config.Config, metrics *metrics.Metrics) (*Manager, error)
// 初始化健康监控
m.initHealthMonitoring()
// 加载持久化的操作计数
m.loadOperationCounts()
return m, nil
}
func (m *Manager) loadOperationCounts() {
if m.storage == nil {
return
}
counts, err := m.storage.GetBucketOperationCounts()
if err != nil {
log.Printf("Failed to load bucket operation counts: %v", err)
return
}
m.mu.RLock()
buckets := make(map[string]*BucketInfo, len(m.buckets))
for name, info := range m.buckets {
buckets[name] = info
}
m.mu.RUnlock()
for name, info := range buckets {
if info == nil {
continue
}
oc, ok := counts[name]
if !ok {
continue
}
if info.SetOperationCount(OperationTypeA, oc.CountA) {
log.Printf("Bucket %s disabled after exceeding A-type operation limit (persisted)", name)
}
if info.SetOperationCount(OperationTypeB, oc.CountB) {
log.Printf("Bucket %s disabled after exceeding B-type operation limit (persisted)", name)
}
}
}
// createS3Client 创建S3客户端
func createS3Client(bucketCfg config.BucketConfig) (*s3.Client, error) {
// 创建自定义端点解析器
@@ -128,12 +184,16 @@ func (m *Manager) initHealthMonitoring() {
// 创建S3健康检查器
healthChecker := health.NewS3Checker(healthConfig)
// 设置操作记录器以统计健康检查的 ListObjects 操作
healthChecker.SetOperationRecorder(reporter)
// 创建健康监控器
m.healthMonitor = health.NewMonitor(healthChecker, reporter)
// 创建统计收集器
statsCollector := health.NewS3StatsCollector(30 * time.Second)
// 设置操作记录器以统计 Stats 收集的 ListObjects 操作
statsCollector.SetOperationRecorder(reporter)
// 创建统计监控器
m.statsMonitor = health.NewStatsMonitor(
@@ -267,6 +327,89 @@ func (b *BucketInfo) UpdateUsedSize(delta int64) {
b.UsedSize += delta
}
// RecordOperation 记录一次后端操作并根据配置判断是否需要禁用存储桶
func (b *BucketInfo) RecordOperation(category OperationCategory) bool {
if b == nil {
return false
}
b.mu.Lock()
defer b.mu.Unlock()
if b.Config.Virtual {
return false
}
var (
limit int64
count *int64
)
switch category {
case OperationTypeA:
b.operationCountA++
count = &b.operationCountA
limit = int64(b.Config.OperationLimits.TypeA)
case OperationTypeB:
b.operationCountB++
count = &b.operationCountB
limit = int64(b.Config.OperationLimits.TypeB)
default:
return false
}
if limit <= 0 || count == nil {
return false
}
if !b.operationLimitReached && *count >= limit {
b.Available = false
b.operationLimitReached = true
return true
}
return false
}
// SetOperationCount 设置指定类别的操作计数并检查上限
func (b *BucketInfo) SetOperationCount(category OperationCategory, value int64) bool {
if b == nil {
return false
}
b.mu.Lock()
defer b.mu.Unlock()
if b.Config.Virtual {
return false
}
var limit int64
switch category {
case OperationTypeA:
b.operationCountA = value
limit = int64(b.Config.OperationLimits.TypeA)
case OperationTypeB:
b.operationCountB = value
limit = int64(b.Config.OperationLimits.TypeB)
default:
return false
}
if limit <= 0 {
return false
}
if !b.operationLimitReached && value >= limit {
b.Available = false
b.operationLimitReached = true
return true
}
return false
}
// IsVirtual 检查是否为虚拟存储桶
func (b *BucketInfo) IsVirtual() bool {
b.mu.RLock()
@@ -274,6 +417,25 @@ func (b *BucketInfo) IsVirtual() bool {
return b.Config.Virtual
}
// GetOperationCount 获取指定类型的操作计数
func (b *BucketInfo) GetOperationCount(category OperationCategory) int64 {
if b == nil {
return 0
}
b.mu.RLock()
defer b.mu.RUnlock()
switch category {
case OperationTypeA:
return b.operationCountA
case OperationTypeB:
return b.operationCountB
default:
return 0
}
}
// GetVirtualBuckets 获取所有虚拟存储桶
func (m *Manager) GetVirtualBuckets() []*BucketInfo {
m.mu.RLock()
@@ -371,6 +533,8 @@ func (m *Manager) UpdateConfig(newConfig *config.Config) error {
m.mu.Unlock()
m.loadOperationCounts()
if restartMonitors {
m.startMonitors()
}

View File

@@ -1,15 +1,17 @@
package bucket
import (
"log"
"github.com/DullJZ/s3-balance/internal/health"
"github.com/DullJZ/s3-balance/internal/metrics"
)
// MetricsReporter 实现 health.HealthReporter 和 health.StatsReporter 接口
type MetricsReporter struct {
metrics *metrics.Metrics
buckets map[string]*BucketInfo
manager *Manager
metrics *metrics.Metrics
buckets map[string]*BucketInfo
manager *Manager
}
// NewMetricsReporter 创建指标报告器
@@ -25,18 +27,20 @@ func (r *MetricsReporter) ReportHealth(targetID string, status health.Status) {
if r.metrics == nil {
return
}
// 更新存储桶可用性状态
r.manager.mu.RLock()
bucket, exists := r.manager.buckets[targetID]
r.manager.mu.RUnlock()
if exists {
bucket.mu.Lock()
bucket.Available = status.Healthy
if !bucket.operationLimitReached {
bucket.Available = status.Healthy
}
bucket.LastChecked = status.LastChecked
bucket.mu.Unlock()
// 更新 Prometheus 指标
r.metrics.SetBucketHealthy(targetID, bucket.Config.Endpoint, status.Healthy)
}
@@ -47,18 +51,68 @@ func (r *MetricsReporter) ReportStats(stats *health.Stats) {
if r.metrics == nil {
return
}
// 更新存储桶使用统计
r.manager.mu.RLock()
bucket, exists := r.manager.buckets[stats.TargetID]
r.manager.mu.RUnlock()
if exists {
bucket.mu.Lock()
bucket.UsedSize = stats.UsedSize
bucket.mu.Unlock()
// 更新 Prometheus 指标
r.metrics.SetBucketUsage(stats.TargetID, stats.UsedSize, bucket.Config.MaxSizeBytes)
}
}
}
// RecordOperation 实现 health.OperationRecorder 接口
func (r *MetricsReporter) RecordOperation(targetID string, category health.OperationCategory) {
r.manager.mu.RLock()
bucket, exists := r.manager.buckets[targetID]
storage := r.manager.storage
r.manager.mu.RUnlock()
if !exists {
return
}
// 转换 health.OperationCategory 到 bucket.OperationCategory
var bucketCategory OperationCategory
switch category {
case health.OperationTypeA:
bucketCategory = OperationTypeA
case health.OperationTypeB:
bucketCategory = OperationTypeB
default:
return
}
// 更新 Prometheus 指标
if r.metrics != nil {
r.metrics.RecordBackendOperation(targetID, string(bucketCategory))
}
// 持久化操作计数到数据库并更新内存计数
var disabled bool
if storage != nil {
// 先持久化到数据库
newCount, err := storage.IncrementBucketOperation(targetID, string(bucketCategory))
if err != nil {
log.Printf("Failed to persist health check operation count for bucket %s: %v", targetID, err)
// 如果数据库更新失败,仍然更新内存计数
disabled = bucket.RecordOperation(bucketCategory)
} else {
// 使用数据库返回的最新计数更新内存
disabled = bucket.SetOperationCount(bucketCategory, newCount)
}
} else {
// 没有 storage service只更新内存
disabled = bucket.RecordOperation(bucketCategory)
}
if disabled {
log.Printf("Bucket %s disabled after exceeding %s-type operation limit (detected by health check)", targetID, bucketCategory)
}
}

View File

@@ -16,6 +16,7 @@ type Config struct {
Balancer BalancerConfig `yaml:"balancer"`
Metrics MetricsConfig `yaml:"metrics"`
S3API S3APIConfig `yaml:"s3api"`
API APIConfig `yaml:"api"`
}
// ServerConfig 服务器配置
@@ -29,17 +30,24 @@ type ServerConfig struct {
// BucketConfig S3存储桶配置
type BucketConfig struct {
Name string `yaml:"name"` // 桶名称
Endpoint string `yaml:"endpoint"` // S3端点
Region string `yaml:"region"` // 区域
AccessKeyID string `yaml:"access_key_id"` // 访问密钥ID
SecretAccessKey string `yaml:"secret_access_key"` // 访问密钥
MaxSize string `yaml:"max_size"` // 最大容量 (例如: "10GB")
MaxSizeBytes int64 `yaml:"-"` // 内部使用,字节为单位
Weight int `yaml:"weight"` // 权重 (用于负载均衡)
Enabled bool `yaml:"enabled"` // 是否启用
PathStyle bool `yaml:"path_style"` // 是否使用路径风格访问
Virtual bool `yaml:"virtual"` // 是否为虚拟存储桶仅S3 API中可见
Name string `yaml:"name"` // 桶名称
Endpoint string `yaml:"endpoint"` // S3端点
Region string `yaml:"region"` // 区域
AccessKeyID string `yaml:"access_key_id"` // 访问密钥ID
SecretAccessKey string `yaml:"secret_access_key"` // 访问密钥
MaxSize string `yaml:"max_size"` // 最大容量 (例如: "10GB")
MaxSizeBytes int64 `yaml:"-"` // 内部使用,字节为单位
Weight int `yaml:"weight"` // 权重 (用于负载均衡)
Enabled bool `yaml:"enabled"` // 是否启用
PathStyle bool `yaml:"path_style"` // 是否使用路径风格访问
Virtual bool `yaml:"virtual"` // 是否为虚拟存储桶仅S3 API中可见
OperationLimits OperationLimitConfig `yaml:"operation_limits"`
}
// OperationLimitConfig 后端操作次数限制配置
type OperationLimitConfig struct {
TypeA int `yaml:"type_a"` // 类型A操作上限0表示不限制
TypeB int `yaml:"type_b"` // 类型B操作上限0表示不限制
}
// BalancerConfig 负载均衡配置
@@ -68,6 +76,12 @@ type S3APIConfig struct {
Host string `yaml:"host"` // 用于签名验证的Host为空则使用请求的Host
}
// APIConfig 管理API配置
type APIConfig struct {
Enabled bool `yaml:"enabled"` // 是否启用管理API
Token string `yaml:"token"` // API访问令牌
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Type string `yaml:"type"` // 数据库类型: sqlite, mysql, postgres
@@ -172,6 +186,11 @@ func (c *Config) SetDefaults() {
if c.S3API.SecretKey == "" {
c.S3API.SecretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
}
// 管理API默认值
if c.API.Token == "" {
c.API.Token = "your-secure-api-token-here"
}
}
// ParseMaxSize 解析最大容量字符串为字节

View File

@@ -1,12 +1,15 @@
package config
import (
"bytes"
"fmt"
"log"
"os"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"gopkg.in/yaml.v3"
)
// Manager 配置管理器,支持热更新
@@ -126,7 +129,7 @@ func (m *Manager) watchConfig() {
// 只处理修改和重命名事件
if event.Op&fsnotify.Write == fsnotify.Write ||
event.Op&fsnotify.Rename == fsnotify.Rename {
event.Op&fsnotify.Rename == fsnotify.Rename {
log.Printf("Config file %s modified (detected by fsnotify), reloading...", m.configFile)
// 更新最后修改时间以避免轮询重复触发
@@ -227,6 +230,178 @@ func (m *Manager) logConfigChanges(oldConfig, newConfig *Config) {
}
}
// UpdateConfig 通过 API 更新配置文件
// 返回错误如果验证失败或写入失败
func (m *Manager) UpdateConfig(newConfig *Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// 1. 验证新配置
if err := m.validateConfig(newConfig); err != nil {
return err
}
// 2. 备份当前配置文件
if err := m.backupConfigFile(); err != nil {
log.Printf("Failed to backup config file: %v", err)
// 继续执行,备份失败不应阻止更新
}
// 3. 将新配置写入文件
if err := m.writeConfigFile(newConfig); err != nil {
return err
}
// 4. 更新内存中的配置
oldConfig := m.config
m.config = newConfig
// 5. 更新最后修改时间,避免文件监听重复触发
if fileInfo, err := os.Stat(m.configFile); err == nil {
m.lastModTime = fileInfo.ModTime()
}
log.Printf("Configuration updated successfully via API")
// 6. 触发配置变更回调(在锁外执行)
callbacks := make([]func(*Config), len(m.callbacks))
copy(callbacks, m.callbacks)
go func() {
for _, callback := range callbacks {
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Config change callback panic: %v", r)
}
}()
callback(newConfig)
}()
}
}()
// 7. 记录配置变更
m.logConfigChanges(oldConfig, newConfig)
return nil
}
// validateConfig 验证配置的有效性
func (m *Manager) validateConfig(cfg *Config) error {
// 基本验证
if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 {
return fmt.Errorf("invalid server port: %d", cfg.Server.Port)
}
if len(cfg.Buckets) == 0 {
return fmt.Errorf("at least one bucket is required")
}
// 验证存储桶配置
for i, bucket := range cfg.Buckets {
if bucket.Name == "" {
return fmt.Errorf("bucket[%d]: name is required", i)
}
// 虚拟存储桶不需要端点和凭据
if !bucket.Virtual {
if bucket.Endpoint == "" {
return fmt.Errorf("bucket[%d] (%s): endpoint is required for non-virtual bucket", i, bucket.Name)
}
if bucket.AccessKeyID == "" {
return fmt.Errorf("bucket[%d] (%s): access_key_id is required for non-virtual bucket", i, bucket.Name)
}
if bucket.SecretAccessKey == "" {
return fmt.Errorf("bucket[%d] (%s): secret_access_key is required for non-virtual bucket", i, bucket.Name)
}
}
// 解析并验证容量大小
if err := cfg.Buckets[i].ParseMaxSize(); err != nil {
return fmt.Errorf("bucket[%d] (%s): invalid max_size: %w", i, bucket.Name, err)
}
}
// 验证负载均衡策略
validStrategies := map[string]bool{
"round-robin": true,
"least-space": true,
"weighted": true,
}
if !validStrategies[cfg.Balancer.Strategy] {
return fmt.Errorf("invalid balancer strategy: %s (must be one of: round-robin, least-space, weighted)", cfg.Balancer.Strategy)
}
// 验证数据库配置
if cfg.Database.Type == "" {
return fmt.Errorf("database type is required")
}
validDBTypes := map[string]bool{
"sqlite": true,
"mysql": true,
"postgres": true,
}
if !validDBTypes[cfg.Database.Type] {
return fmt.Errorf("invalid database type: %s (must be one of: sqlite, mysql, postgres)", cfg.Database.Type)
}
return nil
}
// backupConfigFile 备份当前配置文件
func (m *Manager) backupConfigFile() error {
backupPath := m.configFile + ".backup." + time.Now().Format("20060102-150405")
sourceData, err := os.ReadFile(m.configFile)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
if err := os.WriteFile(backupPath, sourceData, 0644); err != nil {
return fmt.Errorf("failed to write backup file: %w", err)
}
log.Printf("Config file backed up to: %s", backupPath)
return nil
}
// writeConfigFile 将配置写入 YAML 文件
func (m *Manager) writeConfigFile(cfg *Config) error {
// 先编码到缓冲区,避免在写入过程中损坏原文件
var buf bytes.Buffer
encoder := yaml.NewEncoder(&buf)
encoder.SetIndent(2)
if err := encoder.Encode(cfg); err != nil {
return fmt.Errorf("failed to encode config: %w", err)
}
if err := encoder.Close(); err != nil {
return fmt.Errorf("failed to close encoder: %w", err)
}
file, err := os.OpenFile(m.configFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("failed to open config file: %w", err)
}
if _, err := file.Write(buf.Bytes()); err != nil {
file.Close()
return fmt.Errorf("failed to write config file: %w", err)
}
if err := file.Sync(); err != nil {
file.Close()
return fmt.Errorf("failed to sync config file: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("failed to close config file: %w", err)
}
return nil
}
// Close 关闭配置管理器
func (m *Manager) Close() error {
// 停止监听协程
@@ -243,4 +418,4 @@ func (m *Manager) Close() error {
}
return nil
}
}

View File

@@ -181,6 +181,7 @@ func AutoMigrate() error {
models := []interface{}{
&storage.Object{},
&storage.BucketStats{},
&storage.BucketMonthlyStats{},
&storage.UploadSession{},
&storage.AccessLog{},
&storage.VirtualBucketMapping{},

View File

@@ -34,7 +34,8 @@ func (t *S3Target) GetEndpoint() string {
// S3Checker S3健康检查器
type S3Checker struct {
config Config
config Config
opRecorder OperationRecorder
}
// NewS3Checker 创建S3健康检查器
@@ -54,6 +55,11 @@ func NewS3Checker(config Config) *S3Checker {
}
}
// SetOperationRecorder 设置操作记录器
func (c *S3Checker) SetOperationRecorder(recorder OperationRecorder) {
c.opRecorder = recorder
}
// Check 执行S3健康检查
func (c *S3Checker) Check(ctx context.Context, target Target) Status {
s3Target, ok := target.(*S3Target)
@@ -121,6 +127,12 @@ func (c *S3Checker) performSimpleCheck(ctx context.Context, target *S3Target) er
Bucket: aws.String(target.Bucket),
MaxKeys: aws.Int32(1),
})
// 记录操作ListObjectsV2 是 Class A 操作)
if c.opRecorder != nil {
c.opRecorder.RecordOperation(target.GetID(), OperationTypeA)
}
return err
}

View File

@@ -26,7 +26,8 @@ type Stats struct {
// S3StatsCollector S3统计信息收集器
type S3StatsCollector struct {
timeout time.Duration
timeout time.Duration
opRecorder OperationRecorder
}
// NewS3StatsCollector 创建S3统计信息收集器
@@ -39,6 +40,11 @@ func NewS3StatsCollector(timeout time.Duration) *S3StatsCollector {
}
}
// SetOperationRecorder 设置操作记录器
func (c *S3StatsCollector) SetOperationRecorder(recorder OperationRecorder) {
c.opRecorder = recorder
}
// CollectStats 收集S3存储桶统计信息
func (c *S3StatsCollector) CollectStats(ctx context.Context, target Target) (*Stats, error) {
s3Target, ok := target.(*S3Target)
@@ -58,6 +64,12 @@ func (c *S3StatsCollector) CollectStats(ctx context.Context, target Target) (*St
Bucket: aws.String(s3Target.Bucket),
ContinuationToken: continuationToken,
})
// 记录操作(每次 ListObjectsV2 调用都是 Class A 操作)
if c.opRecorder != nil {
c.opRecorder.RecordOperation(s3Target.GetID(), OperationTypeA)
}
if err != nil {
return nil, err
}

View File

@@ -66,3 +66,19 @@ type HealthReporter interface {
// ReportHealth 报告健康状态
ReportHealth(targetID string, status Status)
}
// OperationCategory 操作分类
type OperationCategory string
const (
// OperationTypeA 写入类操作 (ListObjects, PutObject, etc.)
OperationTypeA OperationCategory = "A"
// OperationTypeB 读取类操作 (GetObject)
OperationTypeB OperationCategory = "B"
)
// OperationRecorder 操作记录器接口
type OperationRecorder interface {
// RecordOperation 记录一次后端操作
RecordOperation(targetID string, category OperationCategory)
}

View File

@@ -36,6 +36,11 @@ var (
Name: "s3_balance_balancer_decisions_total",
Help: "Total number of load balancing decisions",
}, []string{"strategy", "bucket"})
backendOperationsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "s3_balance_backend_operations_total",
Help: "Total number of backend bucket operations by category",
}, []string{"bucket", "category"})
)
type Metrics struct{}
@@ -67,4 +72,8 @@ func (m *Metrics) RecordS3OperationDuration(operation, bucket string, duration f
func (m *Metrics) RecordBalancerDecision(strategy, bucket string) {
balancerDecisions.WithLabelValues(strategy, bucket).Inc()
}
}
func (m *Metrics) RecordBackendOperation(bucket, category string) {
backendOperationsTotal.WithLabelValues(bucket, category).Inc()
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strings"
"github.com/DullJZ/s3-validate/pkg/s3validate"
)
@@ -76,3 +77,36 @@ func invokeOnError(w http.ResponseWriter, r *http.Request, cfg S3SignatureConfig
}
http.Error(w, message, http.StatusForbidden)
}
// TokenAuthMiddleware 创建Token认证中间件用于管理API
func TokenAuthMiddleware(validToken string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 从Authorization头中提取token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
w.Header().Set("Content-Type", "application/json")
http.Error(w, `{"error": "missing authorization header"}`, http.StatusUnauthorized)
return
}
// 支持两种格式:
// 1. Bearer <token>
// 2. <token>
token := authHeader
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
// 验证token
if token != validToken {
w.Header().Set("Content-Type", "application/json")
http.Error(w, `{"error": "invalid token"}`, http.StatusUnauthorized)
return
}
// 继续处理请求
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,96 @@
package scheduler
import (
"log"
"time"
"github.com/DullJZ/s3-balance/internal/storage"
)
// MonthlyArchiver 月度统计归档器
type MonthlyArchiver struct {
storage *storage.Service
ticker *time.Ticker
stopChan chan struct{}
lastArchivedDate string // 格式: "2025-01" - 记录上次归档的月份
}
// NewMonthlyArchiver 创建月度归档器
func NewMonthlyArchiver(storage *storage.Service, checkInterval time.Duration) *MonthlyArchiver {
return &MonthlyArchiver{
storage: storage,
ticker: time.NewTicker(checkInterval),
stopChan: make(chan struct{}),
}
}
// Start 启动月度归档定期任务
func (m *MonthlyArchiver) Start() {
log.Println("Starting monthly statistics archiver...")
// 启动时立即归档上个月的数据(如果还没有归档)
m.archiveLastMonth()
go func() {
for {
select {
case <-m.ticker.C:
m.checkAndArchive()
case <-m.stopChan:
log.Println("Monthly statistics archiver stopped")
return
}
}
}()
}
// Stop 停止归档任务
func (m *MonthlyArchiver) Stop() {
close(m.stopChan)
m.ticker.Stop()
}
// checkAndArchive 检查并归档统计数据
func (m *MonthlyArchiver) checkAndArchive() {
now := time.Now()
lastMonth := now.AddDate(0, -1, 0)
lastMonthKey := lastMonth.Format("2006-01")
// 如果是每月的第一天,且上个月还未归档,则归档上个月的数据
if now.Day() == 1 && m.lastArchivedDate != lastMonthKey {
m.archiveLastMonth()
m.lastArchivedDate = lastMonthKey
}
// 每天都归档当前月份(实时更新)
m.archiveCurrentMonth()
}
// archiveLastMonth 归档上个月的数据
func (m *MonthlyArchiver) archiveLastMonth() {
now := time.Now()
lastMonth := now.AddDate(0, -1, 0)
year, month := lastMonth.Year(), int(lastMonth.Month())
log.Printf("Archiving monthly stats for %d-%02d...", year, month)
if err := m.storage.ArchiveMonthlyStats(year, month); err != nil {
log.Printf("Failed to archive monthly stats for %d-%02d: %v", year, month, err)
return
}
log.Printf("Successfully archived monthly stats for %d-%02d", year, month)
}
// archiveCurrentMonth 归档当前月份(实时更新)
func (m *MonthlyArchiver) archiveCurrentMonth() {
now := time.Now()
year, month := now.Year(), int(now.Month())
if err := m.storage.ArchiveMonthlyStats(year, month); err != nil {
log.Printf("Failed to update current month stats for %d-%02d: %v", year, month, err)
return
}
log.Printf("Updated current month stats for %d-%02d", year, month)
}

View File

@@ -29,13 +29,15 @@ func (Object) TableName() string {
// BucketStats 存储桶统计信息模型
type BucketStats struct {
ID uint `gorm:"primaryKey" json:"id"`
BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"`
ObjectCount int64 `gorm:"not null;default:0" json:"object_count"`
TotalSize int64 `gorm:"not null;default:0" json:"total_size"`
LastCheckedAt time.Time `json:"last_checked_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uint `gorm:"primaryKey" json:"id"`
BucketName string `gorm:"uniqueIndex;size:255;not null" json:"bucket_name"`
ObjectCount int64 `gorm:"not null;default:0" json:"object_count"`
TotalSize int64 `gorm:"not null;default:0" json:"total_size"`
OperationCountA int64 `gorm:"not null;default:0" json:"operation_count_a"`
OperationCountB int64 `gorm:"not null;default:0" json:"operation_count_b"`
LastCheckedAt time.Time `json:"last_checked_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName 指定表名
@@ -43,6 +45,23 @@ func (BucketStats) TableName() string {
return "bucket_stats"
}
// BucketMonthlyStats 存储桶月度统计信息模型
type BucketMonthlyStats struct {
ID uint `gorm:"primaryKey" json:"id"`
BucketName string `gorm:"uniqueIndex:idx_bucket_month;size:255;not null" json:"bucket_name"`
Year int `gorm:"uniqueIndex:idx_bucket_month;not null" json:"year"`
Month int `gorm:"uniqueIndex:idx_bucket_month;not null" json:"month"`
OperationCountA int64 `gorm:"not null;default:0" json:"operation_count_a"`
OperationCountB int64 `gorm:"not null;default:0" json:"operation_count_b"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName 指定表名
func (BucketMonthlyStats) TableName() string {
return "bucket_monthly_stats"
}
// VirtualBucketMapping 虚拟存储桶文件级映射模型
type VirtualBucketMapping struct {
ID uint `gorm:"primaryKey" json:"id"`

View File

@@ -1,6 +1,7 @@
package storage
import (
"errors"
"fmt"
"time"
@@ -12,6 +13,12 @@ type Service struct {
db *gorm.DB
}
// OperationCounts 后端操作计数
type OperationCounts struct {
CountA int64
CountB int64
}
// NewService 创建新的存储服务
func NewService(db *gorm.DB) *Service {
return &Service{
@@ -254,6 +261,95 @@ func (s *Service) updateBucketStats(bucketName string) error {
return nil
}
func (s *Service) ensureBucketStats(bucketName string) (*BucketStats, error) {
if bucketName == "" {
return nil, fmt.Errorf("bucket name cannot be empty")
}
stats := &BucketStats{}
err := s.db.Where("bucket_name = ?", bucketName).First(stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = &BucketStats{BucketName: bucketName, LastCheckedAt: time.Now()}
if createErr := s.db.Create(stats).Error; createErr != nil {
// 如果在并发创建下出现重复键,忽略并再次查询
if errors.Is(createErr, gorm.ErrDuplicatedKey) {
if retryErr := s.db.Where("bucket_name = ?", bucketName).First(stats).Error; retryErr != nil {
return nil, fmt.Errorf("failed to fetch bucket stats after duplicate: %w", retryErr)
}
} else {
return nil, fmt.Errorf("failed to create bucket stats: %w", createErr)
}
}
return stats, nil
} else if err != nil {
return nil, fmt.Errorf("failed to fetch bucket stats: %w", err)
}
return stats, nil
}
// IncrementBucketOperation 增加指定存储桶的操作计数
func (s *Service) IncrementBucketOperation(bucketName, category string) (int64, error) {
if _, err := s.ensureBucketStats(bucketName); err != nil {
return 0, err
}
var field string
switch category {
case "A":
field = "operation_count_a"
case "B":
field = "operation_count_b"
default:
return 0, fmt.Errorf("unknown operation category: %s", category)
}
// 使用事务确保原子性
var count int64
err := s.db.Transaction(func(tx *gorm.DB) error {
// 原子递增
if err := tx.Model(&BucketStats{}).
Where("bucket_name = ?", bucketName).
UpdateColumn(field, gorm.Expr(field+" + ?", 1)).Error; err != nil {
return fmt.Errorf("failed to increment %s for bucket %s: %w", field, bucketName, err)
}
// 在同一事务中读取最新值
if err := tx.Model(&BucketStats{}).
Where("bucket_name = ?", bucketName).
Select(field).
Scan(&count).Error; err != nil {
return fmt.Errorf("failed to fetch updated %s for bucket %s: %w", field, bucketName, err)
}
return nil
})
if err != nil {
return 0, err
}
return count, nil
}
// GetBucketOperationCounts 获取所有存储桶的操作计数
func (s *Service) GetBucketOperationCounts() (map[string]OperationCounts, error) {
var stats []BucketStats
if err := s.db.Find(&stats).Error; err != nil {
return nil, fmt.Errorf("failed to list bucket stats: %w", err)
}
result := make(map[string]OperationCounts, len(stats))
for _, st := range stats {
result[st.BucketName] = OperationCounts{
CountA: st.OperationCountA,
CountB: st.OperationCountB,
}
}
return result, nil
}
// RecordUploadSession 记录上传会话
func (s *Service) RecordUploadSession(uploadID, key, bucketName string, size int64) error {
session := &UploadSession{
@@ -545,3 +641,175 @@ func (s *Service) DeleteVirtualBucketFileMapping(virtualBucketName, objectKey st
}
return nil
}
// ArchiveMonthlyStats 归档指定月份的统计数据(存储增量值,非累计值)
// 如果该月份的记录已存在,则更新;否则创建新记录
func (s *Service) ArchiveMonthlyStats(year, month int) error {
// 获取当前所有bucket的累计统计
var currentStats []BucketStats
if err := s.db.Find(&currentStats).Error; err != nil {
return fmt.Errorf("failed to fetch bucket stats: %w", err)
}
// 获取上个月的累计值(从上月归档数据推算)
lastYear, lastMonth := year, month-1
if lastMonth == 0 {
lastMonth = 12
lastYear--
}
// 查询上个月及之前的所有归档数据,用于推算上月末的累计值
var lastMonthArchived []BucketMonthlyStats
lastMonthMap := make(map[string]int64) // bucket_name -> last_month_cumulative_a
lastMonthMapB := make(map[string]int64) // bucket_name -> last_month_cumulative_b
if err := s.db.Where("year < ? OR (year = ? AND month <= ?)", lastYear, lastYear, lastMonth).
Order("year ASC, month ASC").
Find(&lastMonthArchived).Error; err == nil {
// 累加历史增量得到上月末累计值
cumulativeA := make(map[string]int64)
cumulativeB := make(map[string]int64)
for _, archived := range lastMonthArchived {
cumulativeA[archived.BucketName] += archived.OperationCountA
cumulativeB[archived.BucketName] += archived.OperationCountB
}
lastMonthMap = cumulativeA
lastMonthMapB = cumulativeB
}
// 对每个bucket计算本月增量并存储
for _, stat := range currentStats {
lastCumulativeA := lastMonthMap[stat.BucketName]
lastCumulativeB := lastMonthMapB[stat.BucketName]
incrementA := stat.OperationCountA - lastCumulativeA
incrementB := stat.OperationCountB - lastCumulativeB
// 如果是首次运行没有历史数据incrementA/B 可能等于累计值
// 这是预期行为首月记录的就是从0到当前的增量
// 边界情况如果计算出负值说明数据不一致设置为0
if incrementA < 0 {
incrementA = 0
}
if incrementB < 0 {
incrementB = 0
}
monthlyStats := BucketMonthlyStats{
BucketName: stat.BucketName,
Year: year,
Month: month,
OperationCountA: incrementA,
OperationCountB: incrementB,
}
// 使用 UPSERT 逻辑:如果存在则更新,否则创建
if err := s.db.Where("bucket_name = ? AND year = ? AND month = ?",
stat.BucketName, year, month).
Assign(BucketMonthlyStats{
OperationCountA: incrementA,
OperationCountB: incrementB,
}).
FirstOrCreate(&monthlyStats).Error; err != nil {
return fmt.Errorf("failed to archive monthly stats for bucket %s: %w", stat.BucketName, err)
}
}
return nil
}
// GetMonthlyStats 获取指定月份的统计数据
func (s *Service) GetMonthlyStats(year, month int) ([]BucketMonthlyStats, error) {
var stats []BucketMonthlyStats
if err := s.db.Where("year = ? AND month = ?", year, month).
Find(&stats).Error; err != nil {
return nil, fmt.Errorf("failed to fetch monthly stats: %w", err)
}
return stats, nil
}
// GetMonthlyStatsRange 获取指定时间范围的统计数据
func (s *Service) GetMonthlyStatsRange(startYear, startMonth, endYear, endMonth int) ([]BucketMonthlyStats, error) {
var stats []BucketMonthlyStats
if err := s.db.Where("(year > ? OR (year = ? AND month >= ?)) AND (year < ? OR (year = ? AND month <= ?))",
startYear, startYear, startMonth, endYear, endYear, endMonth).
Order("year, month, bucket_name").
Find(&stats).Error; err != nil {
return nil, fmt.Errorf("failed to fetch monthly stats range: %w", err)
}
return stats, nil
}
// GetCurrentMonthStats 获取当前月份的实时统计(从 bucket_stats 计算增量)
func (s *Service) GetCurrentMonthStats() ([]BucketMonthlyStats, error) {
now := time.Now()
year, month := now.Year(), int(now.Month())
// 获取上个月末的累计值(通过累加所有历史增量)
lastYear, lastMonth := year, month-1
if lastMonth == 0 {
lastMonth = 12
lastYear--
}
var historicalStats []BucketMonthlyStats
lastMonthCumulativeA := make(map[string]int64)
lastMonthCumulativeB := make(map[string]int64)
if err := s.db.Where("year < ? OR (year = ? AND month <= ?)", lastYear, lastYear, lastMonth).
Find(&historicalStats).Error; err == nil {
// 累加所有历史增量得到上月末累计值
for _, stat := range historicalStats {
lastMonthCumulativeA[stat.BucketName] += stat.OperationCountA
lastMonthCumulativeB[stat.BucketName] += stat.OperationCountB
}
}
// 获取当前累计数据
var currentStats []BucketStats
if err := s.db.Find(&currentStats).Error; err != nil {
return nil, fmt.Errorf("failed to fetch current bucket stats: %w", err)
}
// 计算当前月份的增量
result := make([]BucketMonthlyStats, 0, len(currentStats))
for _, current := range currentStats {
incrementA := current.OperationCountA - lastMonthCumulativeA[current.BucketName]
incrementB := current.OperationCountB - lastMonthCumulativeB[current.BucketName]
// 边界情况如果计算出负值说明数据不一致设置为0
if incrementA < 0 {
incrementA = 0
}
if incrementB < 0 {
incrementB = 0
}
result = append(result, BucketMonthlyStats{
BucketName: current.BucketName,
Year: year,
Month: month,
OperationCountA: incrementA,
OperationCountB: incrementB,
UpdatedAt: time.Now(),
})
}
return result, nil
}
// GetBucketMonthlyHistory 获取指定存储桶的月度历史统计
func (s *Service) GetBucketMonthlyHistory(bucketName string, months int) ([]BucketMonthlyStats, error) {
var stats []BucketMonthlyStats
if err := s.db.Where("bucket_name = ?", bucketName).
Order("year DESC, month DESC").
Limit(months).
Find(&stats).Error; err != nil {
return nil, fmt.Errorf("failed to fetch bucket monthly history: %w", err)
}
return stats, nil
}

87
internal/web/handler.go Normal file
View File

@@ -0,0 +1,87 @@
package web
import (
"io"
"io/fs"
"net/http"
"path"
"strings"
)
// Handler Web管理界面处理器
type Handler struct {
fileSystem http.FileSystem
}
// NewHandler 创建Web处理器
// distFS 应该是通过 embed.FS 嵌入的 dist 目录
func NewHandler(distFS fs.FS) *Handler {
return &Handler{
fileSystem: http.FS(distFS),
}
}
// ServeHTTP 实现 http.Handler 接口
// 处理单页应用的路由,将所有未找到的路径重定向到 index.html
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 清理路径
p := r.URL.Path
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
// 尝试打开文件
f, err := h.fileSystem.Open(path.Clean(p))
if err != nil {
// 文件不存在,返回 index.html (用于支持前端路由)
indexFile, err := h.fileSystem.Open("index.html")
if err != nil {
http.Error(w, "File not found", http.StatusNotFound)
return
}
defer indexFile.Close()
// 读取 index.html 内容
stat, err := indexFile.Stat()
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
http.ServeContent(w, r, "index.html", stat.ModTime(), indexFile.(io.ReadSeeker))
return
}
defer f.Close()
// 文件存在,检查是否为目录
stat, err := f.Stat()
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
if stat.IsDir() {
// 如果是目录,尝试返回 index.html
indexPath := path.Join(p, "index.html")
indexFile, err := h.fileSystem.Open(indexPath)
if err != nil {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
defer indexFile.Close()
indexStat, err := indexFile.Stat()
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
http.ServeContent(w, r, "index.html", indexStat.ModTime(), indexFile.(io.ReadSeeker))
return
}
// 返回文件内容
http.ServeContent(w, r, stat.Name(), stat.ModTime(), f.(io.ReadSeeker))
}

14
internal/webui/embed.go Normal file
View File

@@ -0,0 +1,14 @@
package webui
import (
"embed"
"io/fs"
)
//go:embed dist
var distFS embed.FS
// GetDistFS 获取嵌入的前端静态文件系统
func GetDistFS() (fs.FS, error) {
return fs.Sub(distFS, "dist")
}