diff --git a/cmd/s3-balance/main.go b/cmd/s3-balance/main.go index e888730..f7287c0 100644 --- a/cmd/s3-balance/main.go +++ b/cmd/s3-balance/main.go @@ -65,6 +65,9 @@ func main() { // 创建存储服务 storageService := storage.NewService(database.GetDB()) + // 启动定期清理过期上传会话的任务 + startSessionCleaner(ctx, storageService) + // 创建S3兼容API处理器 s3Handler := api.NewS3Handler( bucketManager, @@ -180,3 +183,51 @@ func (rw *responseWriter) WriteHeader(code int) { rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } + +// startSessionCleaner 启动定期清理过期会话的任务 +func startSessionCleaner(ctx context.Context, storageService *storage.Service) { + go func() { + // 初始延迟,避免启动时立即执行 + time.Sleep(1 * time.Minute) + + // 每小时清理一次过期的上传会话 + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Println("Stopping session cleaner") + return + case <-ticker.C: + log.Println("Cleaning expired upload sessions...") + if err := storageService.CleanExpiredSessions(); err != nil { + log.Printf("Failed to clean expired sessions: %v", err) + } else { + log.Println("Successfully cleaned expired upload sessions") + } + + // 同时中止在S3存储桶中过期的分片上传 + cleanupS3MultipartUploads(ctx, storageService) + } + } + }() +} + +// cleanupS3MultipartUploads 清理S3存储桶中过期的分片上传 +func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Service) { + // 获取所有过期的会话 + sessions, err := storageService.GetPendingUploadSessions("", "", "", 0) + if err != nil { + log.Printf("Failed to get pending sessions for cleanup: %v", err) + return + } + + for _, session := range sessions { + if session.IsExpired() { + log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key) + // 此处可以添加调用S3 AbortMultipartUpload的逻辑 + // 但需要知道对应的真实存储桶信息 + } + } +} diff --git a/internal/api/s3_handler.go b/internal/api/s3_handler.go index 54d8051..b5824ad 100644 --- a/internal/api/s3_handler.go +++ b/internal/api/s3_handler.go @@ -843,6 +843,18 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) { if etag != "" { w.Header().Set("ETag", etag) } + + // 更新上传会话的分片数 + session, err := h.storage.GetUploadSession(uploadID) + if err != nil { + log.Printf("Failed to get upload session for uploadID %s: %v", uploadID, err) + } else { + // 更新已完成的分片数 + if err := h.storage.UpdateUploadSession(uploadID, session.CompletedParts+1, "pending"); err != nil { + log.Printf("Failed to update upload session for uploadID %s: %v", uploadID, err) + } + } + w.WriteHeader(http.StatusOK) } else { // 读取错误响应体以获取详细信息 @@ -899,11 +911,19 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request return } + uploadID := *createResp.UploadId + + // 记录上传会话到数据库 + if err := h.storage.RecordUploadSession(uploadID, key, targetBucket.Config.Name, 0, 0); err != nil { + log.Printf("Failed to record upload session for uploadID %s: %v", uploadID, err) + // 不影响主流程,继续处理 + } + result := InitiateMultipartUploadResult{ Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", Bucket: bucketName, // 返回虚拟存储桶名称给客户端 Key: key, - UploadID: *createResp.UploadId, + UploadID: uploadID, } h.sendXMLResponse(w, http.StatusOK, result) @@ -913,22 +933,124 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) bucketName := vars["bucket"] - key := vars["key"] - // 检查bucket是否存在 - if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + // 解析查询参数 + queryParams := r.URL.Query() + keyMarker := queryParams.Get("key-marker") + uploadIdMarker := queryParams.Get("upload-id-marker") + prefix := queryParams.Get("prefix") + delimiter := queryParams.Get("delimiter") + maxUploadsStr := queryParams.Get("max-uploads") + maxUploads := 1000 + if maxUploadsStr != "" { + if m, err := strconv.Atoi(maxUploadsStr); err == nil && m > 0 { + maxUploads = m + } + } + + // 检查请求的存储桶 + requestedBucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) return } - // 简化实现:返回空列表 + var allUploads []Upload + var isTruncated bool + + // 如果是虚拟存储桶,从数据库查询上传会话 + if requestedBucket.IsVirtual() { + // 从数据库获取待处理的上传会话 + sessions, err := h.storage.GetPendingUploadSessions(prefix, keyMarker, uploadIdMarker, maxUploads) + if err != nil { + log.Printf("Failed to get pending upload sessions: %v", err) + // 降级到遍历所有存储桶的方式 + ctx := context.Background() + allBuckets := h.bucketManager.GetAllBuckets() + for _, bucket := range allBuckets { + if bucket.IsVirtual() { + continue + } + + // 列出每个真实存储桶的分片上传 + listResp, err := bucket.Client.ListMultipartUploads(ctx, &s3.ListMultipartUploadsInput{ + Bucket: aws.String(bucket.Config.Name), + KeyMarker: aws.String(keyMarker), + UploadIdMarker: aws.String(uploadIdMarker), + Prefix: aws.String(prefix), + Delimiter: aws.String(delimiter), + MaxUploads: aws.Int32(int32(maxUploads)), + }) + if err != nil { + log.Printf("Failed to list multipart uploads for bucket %s: %v", bucket.Config.Name, err) + continue + } + + // 将结果添加到列表中 + for _, upload := range listResp.Uploads { + allUploads = append(allUploads, Upload{ + Key: aws.ToString(upload.Key), + UploadID: aws.ToString(upload.UploadId), + Initiator: Owner{ + ID: aws.ToString(upload.Initiator.ID), + DisplayName: aws.ToString(upload.Initiator.DisplayName), + }, + Owner: Owner{ + ID: aws.ToString(upload.Owner.ID), + DisplayName: aws.ToString(upload.Owner.DisplayName), + }, + StorageClass: string(upload.StorageClass), + Initiated: aws.ToTime(upload.Initiated), + }) + } + } + } else { + // 成功从数据库获取会话 + if len(sessions) > maxUploads { + sessions = sessions[:maxUploads] + isTruncated = true + } + + // 转换会话为Upload格式 + for _, session := range sessions { + allUploads = append(allUploads, Upload{ + Key: session.Key, + UploadID: session.UploadID, + Initiator: Owner{ + ID: "s3-balance", + DisplayName: "S3 Balance User", + }, + Owner: Owner{ + ID: "s3-balance", + DisplayName: "S3 Balance User", + }, + StorageClass: "STANDARD", + Initiated: session.CreatedAt, + }) + } + } + } else { + // 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作 + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 构建响应 result := ListMultipartUploadsResult{ Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", Bucket: bucketName, - KeyMarker: key, - MaxUploads: 1000, - IsTruncated: false, - Uploads: make([]Upload, 0), + KeyMarker: keyMarker, + UploadIdMarker: uploadIdMarker, + MaxUploads: maxUploads, + IsTruncated: isTruncated, + Uploads: allUploads, + } + + // 如果有更多结果,设置下一个标记 + if isTruncated && len(allUploads) > 0 { + lastUpload := allUploads[len(allUploads)-1] + result.NextKeyMarker = lastUpload.Key + result.NextUploadIdMarker = lastUpload.UploadID } h.sendXMLResponse(w, http.StatusOK, result) @@ -941,22 +1063,117 @@ func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Requ key := vars["key"] uploadID := r.URL.Query().Get("uploadId") - // 检查bucket是否存在 - if _, ok := h.bucketManager.GetBucket(bucketName); !ok { + // 解析查询参数 + queryParams := r.URL.Query() + partNumberMarkerStr := queryParams.Get("part-number-marker") + partNumberMarker := 0 + if partNumberMarkerStr != "" { + if m, err := strconv.Atoi(partNumberMarkerStr); err == nil && m > 0 { + partNumberMarker = m + } + } + maxPartsStr := queryParams.Get("max-parts") + maxParts := 1000 + if maxPartsStr != "" { + if m, err := strconv.Atoi(maxPartsStr); err == nil && m > 0 { + maxParts = m + } + } + + // 检查请求的存储桶 + requestedBucket, ok := h.bucketManager.GetBucket(bucketName) + if !ok { h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) return } - // 简化实现:返回空列表 + var targetBucket *bucket.BucketInfo + + // 如果是虚拟存储桶,需要通过映射查找真实存储桶 + if requestedBucket.IsVirtual() { + // 获取虚拟存储桶映射 + mapping, err := h.storage.GetVirtualBucketMapping(bucketName, key) + if err != nil { + // 如果没有找到映射,尝试查询所有真实存储桶 + allBuckets := h.bucketManager.GetAllBuckets() + for _, bucket := range allBuckets { + if bucket.IsVirtual() { + continue + } + // 尝试列出分片,如果成功则说明上传在这个桶中 + ctx := context.Background() + _, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{ + Bucket: aws.String(bucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)), + MaxParts: aws.Int32(1), // 只检查是否存在 + }) + if err == nil { + targetBucket = bucket + break + } + } + if targetBucket == nil { + h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID) + return + } + } else { + // 获取映射到的真实存储桶 + targetBucket, ok = h.bucketManager.GetBucket(mapping.RealBucketName) + if !ok { + h.sendS3Error(w, "InternalError", "Mapped real bucket not found", key) + return + } + } + } else { + // 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作 + h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName) + return + } + + // 列出分片 + ctx := context.Background() + listResp, err := targetBucket.Client.ListParts(ctx, &s3.ListPartsInput{ + Bucket: aws.String(targetBucket.Config.Name), + Key: aws.String(key), + UploadId: aws.String(uploadID), + PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)), + MaxParts: aws.Int32(int32(maxParts)), + }) + if err != nil { + h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID) + return + } + + // 转换分片列表 + var parts []Part + for _, part := range listResp.Parts { + parts = append(parts, Part{ + PartNumber: int(aws.ToInt32(part.PartNumber)), + LastModified: aws.ToTime(part.LastModified), + ETag: aws.ToString(part.ETag), + Size: aws.ToInt64(part.Size), + }) + } + + // 构建响应 result := ListPartsResult{ Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/", - Bucket: bucketName, + Bucket: bucketName, // 返回虚拟存储桶名称给客户端 Key: key, UploadID: uploadID, - PartNumberMarker: 0, - MaxParts: 1000, - IsTruncated: false, - Parts: make([]Part, 0), + PartNumberMarker: partNumberMarker, + MaxParts: maxParts, + IsTruncated: aws.ToBool(listResp.IsTruncated), + Parts: parts, + } + + // 设置下一个分片标记 + if listResp.NextPartNumberMarker != nil { + if nextMarker, err := strconv.Atoi(aws.ToString(listResp.NextPartNumberMarker)); err == nil { + result.NextPartNumberMarker = nextMarker + } } h.sendXMLResponse(w, http.StatusOK, result) @@ -1002,7 +1219,18 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http // 解析请求体以获取分片列表 var completeReq CompleteMultipartUpload body, _ := io.ReadAll(r.Body) - xml.Unmarshal(body, &completeReq) + err := xml.Unmarshal(body, &completeReq) + if err != nil { + log.Printf("Failed to parse CompleteMultipartUpload request body: %v, body: %s", err, string(body)) + h.sendS3Error(w, "MalformedXML", "The XML you provided was not well-formed", key) + return + } + + log.Printf("CompleteMultipartUpload request - Bucket: %s, Key: %s, UploadID: %s, Parts: %d", + bucketName, key, uploadID, len(completeReq.Parts)) + for i, part := range completeReq.Parts { + log.Printf(" Part %d: PartNumber=%d, ETag=%s", i+1, part.PartNumber, part.ETag) + } // 完成分片上传 ctx := context.Background() @@ -1014,6 +1242,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http }) } + log.Printf("Calling CompleteMultipartUpload on real bucket %s with uploadID %s", targetBucket.Config.Name, uploadID) completeResp, err := targetBucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(targetBucket.Config.Name), Key: aws.String(key), @@ -1023,6 +1252,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http }, }) if err != nil { + log.Printf("CompleteMultipartUpload failed: %v", err) h.sendS3Error(w, "InternalError", "Failed to complete multipart upload", key) return } @@ -1057,6 +1287,12 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http targetBucket.UpdateUsedSize(objectSize) } + // 更新上传会话状态为已完成 + if err := h.storage.UpdateUploadSession(uploadID, len(completeReq.Parts), "completed"); err != nil { + log.Printf("Failed to update upload session status to completed for uploadID %s: %v", uploadID, err) + // 不影响主流程 + } + h.sendXMLResponse(w, http.StatusOK, result) } @@ -1110,6 +1346,12 @@ func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Re log.Printf("Failed to abort multipart upload for key %s: %v", key, err) } + // 更新上传会话状态为已中止 + if err := h.storage.UpdateUploadSession(uploadID, 0, "aborted"); err != nil { + log.Printf("Failed to update upload session status to aborted for uploadID %s: %v", uploadID, err) + // 不影响主流程 + } + // 如果是虚拟存储桶,还需要删除文件级别映射 if requestedBucket.IsVirtual() { h.storage.DeleteVirtualBucketFileMapping(bucketName, key) diff --git a/internal/storage/models.go b/internal/storage/models.go index a187459..95918a4 100644 --- a/internal/storage/models.go +++ b/internal/storage/models.go @@ -61,7 +61,7 @@ func (VirtualBucketMapping) TableName() string { // UploadSession 上传会话模型(用于跟踪分片上传) type UploadSession struct { ID uint `gorm:"primaryKey" json:"id"` - UploadID string `gorm:"uniqueIndex;size:255;not null" json:"upload_id"` + UploadID string `gorm:"uniqueIndex;size:512;not null" json:"upload_id"` // 增加到512字符以支持长uploadID Key string `gorm:"index;size:512;not null" json:"key"` BucketName string `gorm:"index;size:255;not null" json:"bucket_name"` TotalParts int `gorm:"not null;default:0" json:"total_parts"` diff --git a/internal/storage/service.go b/internal/storage/service.go index ab908c4..d1c6f99 100644 --- a/internal/storage/service.go +++ b/internal/storage/service.go @@ -301,6 +301,41 @@ func (s *Service) UpdateUploadSession(uploadID string, completedParts int, statu return nil } +// GetPendingUploadSessions 获取正在进行中的上传会话 +func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uploadIdMarker string, maxUploads int) ([]*UploadSession, error) { + query := s.db.Model(&UploadSession{}).Where("status = ?", "pending") + + // 根据前缀过滤 + if prefix != "" { + query = query.Where("`key` LIKE ?", prefix+"%") + } + + // 分页标记处理 + if keyMarker != "" { + if uploadIdMarker != "" { + // 如果同时指定了key和uploadId标记 + query = query.Where("(`key` > ? OR (`key` = ? AND upload_id > ?))", keyMarker, keyMarker, uploadIdMarker) + } else { + query = query.Where("`key` > ?", keyMarker) + } + } + + // 限制返回数量 + if maxUploads > 0 { + query = query.Limit(maxUploads + 1) // 多查询一个以判断是否截断 + } + + // 按key和uploadID排序 + query = query.Order("`key` ASC, upload_id ASC") + + var sessions []*UploadSession + if err := query.Find(&sessions).Error; err != nil { + return nil, fmt.Errorf("failed to get pending upload sessions: %w", err) + } + + return sessions, nil +} + // CleanExpiredSessions 清理过期的上传会话 func (s *Service) CleanExpiredSessions() error { if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending").