Support upload sessions

This commit is contained in:
DullJZ
2025-09-11 16:47:25 +08:00
parent 7311ffeeae
commit ffe5c497ae
4 changed files with 347 additions and 19 deletions

View File

@@ -65,6 +65,9 @@ func main() {
// 创建存储服务
storageService := storage.NewService(database.GetDB())
// 启动定期清理过期上传会话的任务
startSessionCleaner(ctx, storageService)
// 创建S3兼容API处理器
s3Handler := api.NewS3Handler(
bucketManager,
@@ -180,3 +183,51 @@ func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// startSessionCleaner 启动定期清理过期会话的任务
func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
go func() {
// 初始延迟,避免启动时立即执行
time.Sleep(1 * time.Minute)
// 每小时清理一次过期的上传会话
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Println("Stopping session cleaner")
return
case <-ticker.C:
log.Println("Cleaning expired upload sessions...")
if err := storageService.CleanExpiredSessions(); err != nil {
log.Printf("Failed to clean expired sessions: %v", err)
} else {
log.Println("Successfully cleaned expired upload sessions")
}
// 同时中止在S3存储桶中过期的分片上传
cleanupS3MultipartUploads(ctx, storageService)
}
}
}()
}
// cleanupS3MultipartUploads 清理S3存储桶中过期的分片上传
func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Service) {
// 获取所有过期的会话
sessions, err := storageService.GetPendingUploadSessions("", "", "", 0)
if err != nil {
log.Printf("Failed to get pending sessions for cleanup: %v", err)
return
}
for _, session := range sessions {
if session.IsExpired() {
log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key)
// 此处可以添加调用S3 AbortMultipartUpload的逻辑
// 但需要知道对应的真实存储桶信息
}
}
}

View File

@@ -843,6 +843,18 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
if etag != "" {
w.Header().Set("ETag", etag)
}
// 更新上传会话的分片数
session, err := h.storage.GetUploadSession(uploadID)
if err != nil {
log.Printf("Failed to get upload session for uploadID %s: %v", uploadID, err)
} else {
// 更新已完成的分片数
if err := h.storage.UpdateUploadSession(uploadID, session.CompletedParts+1, "pending"); err != nil {
log.Printf("Failed to update upload session for uploadID %s: %v", uploadID, err)
}
}
w.WriteHeader(http.StatusOK)
} else {
// 读取错误响应体以获取详细信息
@@ -899,11 +911,19 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
return
}
uploadID := *createResp.UploadId
// 记录上传会话到数据库
if err := h.storage.RecordUploadSession(uploadID, key, targetBucket.Config.Name, 0, 0); err != nil {
log.Printf("Failed to record upload session for uploadID %s: %v", uploadID, err)
// 不影响主流程,继续处理
}
result := InitiateMultipartUploadResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: bucketName, // 返回虚拟存储桶名称给客户端
Key: key,
UploadID: *createResp.UploadId,
UploadID: uploadID,
}
h.sendXMLResponse(w, http.StatusOK, result)
@@ -913,22 +933,124 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
bucketName := vars["bucket"]
key := vars["key"]
// 检查bucket是否存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
// 解析查询参数
queryParams := r.URL.Query()
keyMarker := queryParams.Get("key-marker")
uploadIdMarker := queryParams.Get("upload-id-marker")
prefix := queryParams.Get("prefix")
delimiter := queryParams.Get("delimiter")
maxUploadsStr := queryParams.Get("max-uploads")
maxUploads := 1000
if maxUploadsStr != "" {
if m, err := strconv.Atoi(maxUploadsStr); err == nil && m > 0 {
maxUploads = m
}
}
// 检查请求的存储桶
requestedBucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 简化实现:返回空列表
var allUploads []Upload
var isTruncated bool
// 如果是虚拟存储桶,从数据库查询上传会话
if requestedBucket.IsVirtual() {
// 从数据库获取待处理的上传会话
sessions, err := h.storage.GetPendingUploadSessions(prefix, keyMarker, uploadIdMarker, maxUploads)
if err != nil {
log.Printf("Failed to get pending upload sessions: %v", err)
// 降级到遍历所有存储桶的方式
ctx := context.Background()
allBuckets := h.bucketManager.GetAllBuckets()
for _, bucket := range allBuckets {
if bucket.IsVirtual() {
continue
}
// 列出每个真实存储桶的分片上传
listResp, err := bucket.Client.ListMultipartUploads(ctx, &s3.ListMultipartUploadsInput{
Bucket: aws.String(bucket.Config.Name),
KeyMarker: aws.String(keyMarker),
UploadIdMarker: aws.String(uploadIdMarker),
Prefix: aws.String(prefix),
Delimiter: aws.String(delimiter),
MaxUploads: aws.Int32(int32(maxUploads)),
})
if err != nil {
log.Printf("Failed to list multipart uploads for bucket %s: %v", bucket.Config.Name, err)
continue
}
// 将结果添加到列表中
for _, upload := range listResp.Uploads {
allUploads = append(allUploads, Upload{
Key: aws.ToString(upload.Key),
UploadID: aws.ToString(upload.UploadId),
Initiator: Owner{
ID: aws.ToString(upload.Initiator.ID),
DisplayName: aws.ToString(upload.Initiator.DisplayName),
},
Owner: Owner{
ID: aws.ToString(upload.Owner.ID),
DisplayName: aws.ToString(upload.Owner.DisplayName),
},
StorageClass: string(upload.StorageClass),
Initiated: aws.ToTime(upload.Initiated),
})
}
}
} else {
// 成功从数据库获取会话
if len(sessions) > maxUploads {
sessions = sessions[:maxUploads]
isTruncated = true
}
// 转换会话为Upload格式
for _, session := range sessions {
allUploads = append(allUploads, Upload{
Key: session.Key,
UploadID: session.UploadID,
Initiator: Owner{
ID: "s3-balance",
DisplayName: "S3 Balance User",
},
Owner: Owner{
ID: "s3-balance",
DisplayName: "S3 Balance User",
},
StorageClass: "STANDARD",
Initiated: session.CreatedAt,
})
}
}
} else {
// 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 构建响应
result := ListMultipartUploadsResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: bucketName,
KeyMarker: key,
MaxUploads: 1000,
IsTruncated: false,
Uploads: make([]Upload, 0),
KeyMarker: keyMarker,
UploadIdMarker: uploadIdMarker,
MaxUploads: maxUploads,
IsTruncated: isTruncated,
Uploads: allUploads,
}
// 如果有更多结果,设置下一个标记
if isTruncated && len(allUploads) > 0 {
lastUpload := allUploads[len(allUploads)-1]
result.NextKeyMarker = lastUpload.Key
result.NextUploadIdMarker = lastUpload.UploadID
}
h.sendXMLResponse(w, http.StatusOK, result)
@@ -941,22 +1063,117 @@ func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Requ
key := vars["key"]
uploadID := r.URL.Query().Get("uploadId")
// 检查bucket是否存在
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
// 解析查询参数
queryParams := r.URL.Query()
partNumberMarkerStr := queryParams.Get("part-number-marker")
partNumberMarker := 0
if partNumberMarkerStr != "" {
if m, err := strconv.Atoi(partNumberMarkerStr); err == nil && m > 0 {
partNumberMarker = m
}
}
maxPartsStr := queryParams.Get("max-parts")
maxParts := 1000
if maxPartsStr != "" {
if m, err := strconv.Atoi(maxPartsStr); err == nil && m > 0 {
maxParts = m
}
}
// 检查请求的存储桶
requestedBucket, ok := h.bucketManager.GetBucket(bucketName)
if !ok {
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 简化实现:返回空列表
var targetBucket *bucket.BucketInfo
// 如果是虚拟存储桶,需要通过映射查找真实存储桶
if requestedBucket.IsVirtual() {
// 获取虚拟存储桶映射
mapping, err := h.storage.GetVirtualBucketMapping(bucketName, key)
if err != nil {
// 如果没有找到映射,尝试查询所有真实存储桶
allBuckets := h.bucketManager.GetAllBuckets()
for _, bucket := range allBuckets {
if bucket.IsVirtual() {
continue
}
// 尝试列出分片,如果成功则说明上传在这个桶中
ctx := context.Background()
_, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(bucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
MaxParts: aws.Int32(1), // 只检查是否存在
})
if err == nil {
targetBucket = bucket
break
}
}
if targetBucket == nil {
h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID)
return
}
} else {
// 获取映射到的真实存储桶
targetBucket, ok = h.bucketManager.GetBucket(mapping.RealBucketName)
if !ok {
h.sendS3Error(w, "InternalError", "Mapped real bucket not found", key)
return
}
}
} else {
// 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
return
}
// 列出分片
ctx := context.Background()
listResp, err := targetBucket.Client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
MaxParts: aws.Int32(int32(maxParts)),
})
if err != nil {
h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID)
return
}
// 转换分片列表
var parts []Part
for _, part := range listResp.Parts {
parts = append(parts, Part{
PartNumber: int(aws.ToInt32(part.PartNumber)),
LastModified: aws.ToTime(part.LastModified),
ETag: aws.ToString(part.ETag),
Size: aws.ToInt64(part.Size),
})
}
// 构建响应
result := ListPartsResult{
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
Bucket: bucketName,
Bucket: bucketName, // 返回虚拟存储桶名称给客户端
Key: key,
UploadID: uploadID,
PartNumberMarker: 0,
MaxParts: 1000,
IsTruncated: false,
Parts: make([]Part, 0),
PartNumberMarker: partNumberMarker,
MaxParts: maxParts,
IsTruncated: aws.ToBool(listResp.IsTruncated),
Parts: parts,
}
// 设置下一个分片标记
if listResp.NextPartNumberMarker != nil {
if nextMarker, err := strconv.Atoi(aws.ToString(listResp.NextPartNumberMarker)); err == nil {
result.NextPartNumberMarker = nextMarker
}
}
h.sendXMLResponse(w, http.StatusOK, result)
@@ -1002,7 +1219,18 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
// 解析请求体以获取分片列表
var completeReq CompleteMultipartUpload
body, _ := io.ReadAll(r.Body)
xml.Unmarshal(body, &completeReq)
err := xml.Unmarshal(body, &completeReq)
if err != nil {
log.Printf("Failed to parse CompleteMultipartUpload request body: %v, body: %s", err, string(body))
h.sendS3Error(w, "MalformedXML", "The XML you provided was not well-formed", key)
return
}
log.Printf("CompleteMultipartUpload request - Bucket: %s, Key: %s, UploadID: %s, Parts: %d",
bucketName, key, uploadID, len(completeReq.Parts))
for i, part := range completeReq.Parts {
log.Printf(" Part %d: PartNumber=%d, ETag=%s", i+1, part.PartNumber, part.ETag)
}
// 完成分片上传
ctx := context.Background()
@@ -1014,6 +1242,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
})
}
log.Printf("Calling CompleteMultipartUpload on real bucket %s with uploadID %s", targetBucket.Config.Name, uploadID)
completeResp, err := targetBucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),
@@ -1023,6 +1252,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
},
})
if err != nil {
log.Printf("CompleteMultipartUpload failed: %v", err)
h.sendS3Error(w, "InternalError", "Failed to complete multipart upload", key)
return
}
@@ -1057,6 +1287,12 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
targetBucket.UpdateUsedSize(objectSize)
}
// 更新上传会话状态为已完成
if err := h.storage.UpdateUploadSession(uploadID, len(completeReq.Parts), "completed"); err != nil {
log.Printf("Failed to update upload session status to completed for uploadID %s: %v", uploadID, err)
// 不影响主流程
}
h.sendXMLResponse(w, http.StatusOK, result)
}
@@ -1110,6 +1346,12 @@ func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Re
log.Printf("Failed to abort multipart upload for key %s: %v", key, err)
}
// 更新上传会话状态为已中止
if err := h.storage.UpdateUploadSession(uploadID, 0, "aborted"); err != nil {
log.Printf("Failed to update upload session status to aborted for uploadID %s: %v", uploadID, err)
// 不影响主流程
}
// 如果是虚拟存储桶,还需要删除文件级别映射
if requestedBucket.IsVirtual() {
h.storage.DeleteVirtualBucketFileMapping(bucketName, key)

View File

@@ -61,7 +61,7 @@ func (VirtualBucketMapping) TableName() string {
// UploadSession 上传会话模型(用于跟踪分片上传)
type UploadSession struct {
ID uint `gorm:"primaryKey" json:"id"`
UploadID string `gorm:"uniqueIndex;size:255;not null" json:"upload_id"`
UploadID string `gorm:"uniqueIndex;size:512;not null" json:"upload_id"` // 增加到512字符以支持长uploadID
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
TotalParts int `gorm:"not null;default:0" json:"total_parts"`

View File

@@ -301,6 +301,41 @@ func (s *Service) UpdateUploadSession(uploadID string, completedParts int, statu
return nil
}
// GetPendingUploadSessions 获取正在进行中的上传会话
func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uploadIdMarker string, maxUploads int) ([]*UploadSession, error) {
query := s.db.Model(&UploadSession{}).Where("status = ?", "pending")
// 根据前缀过滤
if prefix != "" {
query = query.Where("`key` LIKE ?", prefix+"%")
}
// 分页标记处理
if keyMarker != "" {
if uploadIdMarker != "" {
// 如果同时指定了key和uploadId标记
query = query.Where("(`key` > ? OR (`key` = ? AND upload_id > ?))", keyMarker, keyMarker, uploadIdMarker)
} else {
query = query.Where("`key` > ?", keyMarker)
}
}
// 限制返回数量
if maxUploads > 0 {
query = query.Limit(maxUploads + 1) // 多查询一个以判断是否截断
}
// 按key和uploadID排序
query = query.Order("`key` ASC, upload_id ASC")
var sessions []*UploadSession
if err := query.Find(&sessions).Error; err != nil {
return nil, fmt.Errorf("failed to get pending upload sessions: %w", err)
}
return sessions, nil
}
// CleanExpiredSessions 清理过期的上传会话
func (s *Service) CleanExpiredSessions() error {
if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending").