mirror of
https://github.com/DullJZ/s3-balance.git
synced 2026-06-30 15:41:22 +08:00
Support upload sessions
This commit is contained in:
@@ -65,6 +65,9 @@ func main() {
|
||||
// 创建存储服务
|
||||
storageService := storage.NewService(database.GetDB())
|
||||
|
||||
// 启动定期清理过期上传会话的任务
|
||||
startSessionCleaner(ctx, storageService)
|
||||
|
||||
// 创建S3兼容API处理器
|
||||
s3Handler := api.NewS3Handler(
|
||||
bucketManager,
|
||||
@@ -180,3 +183,51 @@ func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// startSessionCleaner 启动定期清理过期会话的任务
|
||||
func startSessionCleaner(ctx context.Context, storageService *storage.Service) {
|
||||
go func() {
|
||||
// 初始延迟,避免启动时立即执行
|
||||
time.Sleep(1 * time.Minute)
|
||||
|
||||
// 每小时清理一次过期的上传会话
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Stopping session cleaner")
|
||||
return
|
||||
case <-ticker.C:
|
||||
log.Println("Cleaning expired upload sessions...")
|
||||
if err := storageService.CleanExpiredSessions(); err != nil {
|
||||
log.Printf("Failed to clean expired sessions: %v", err)
|
||||
} else {
|
||||
log.Println("Successfully cleaned expired upload sessions")
|
||||
}
|
||||
|
||||
// 同时中止在S3存储桶中过期的分片上传
|
||||
cleanupS3MultipartUploads(ctx, storageService)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupS3MultipartUploads 清理S3存储桶中过期的分片上传
|
||||
func cleanupS3MultipartUploads(ctx context.Context, storageService *storage.Service) {
|
||||
// 获取所有过期的会话
|
||||
sessions, err := storageService.GetPendingUploadSessions("", "", "", 0)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get pending sessions for cleanup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
if session.IsExpired() {
|
||||
log.Printf("Found expired session: uploadID=%s, key=%s", session.UploadID, session.Key)
|
||||
// 此处可以添加调用S3 AbortMultipartUpload的逻辑
|
||||
// 但需要知道对应的真实存储桶信息
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -843,6 +843,18 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
|
||||
if etag != "" {
|
||||
w.Header().Set("ETag", etag)
|
||||
}
|
||||
|
||||
// 更新上传会话的分片数
|
||||
session, err := h.storage.GetUploadSession(uploadID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get upload session for uploadID %s: %v", uploadID, err)
|
||||
} else {
|
||||
// 更新已完成的分片数
|
||||
if err := h.storage.UpdateUploadSession(uploadID, session.CompletedParts+1, "pending"); err != nil {
|
||||
log.Printf("Failed to update upload session for uploadID %s: %v", uploadID, err)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
// 读取错误响应体以获取详细信息
|
||||
@@ -899,11 +911,19 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
uploadID := *createResp.UploadId
|
||||
|
||||
// 记录上传会话到数据库
|
||||
if err := h.storage.RecordUploadSession(uploadID, key, targetBucket.Config.Name, 0, 0); err != nil {
|
||||
log.Printf("Failed to record upload session for uploadID %s: %v", uploadID, err)
|
||||
// 不影响主流程,继续处理
|
||||
}
|
||||
|
||||
result := InitiateMultipartUploadResult{
|
||||
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
|
||||
Bucket: bucketName, // 返回虚拟存储桶名称给客户端
|
||||
Key: key,
|
||||
UploadID: *createResp.UploadId,
|
||||
UploadID: uploadID,
|
||||
}
|
||||
|
||||
h.sendXMLResponse(w, http.StatusOK, result)
|
||||
@@ -913,22 +933,124 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
|
||||
func (h *S3Handler) handleListMultipartUploads(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
bucketName := vars["bucket"]
|
||||
key := vars["key"]
|
||||
|
||||
// 检查bucket是否存在
|
||||
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
|
||||
// 解析查询参数
|
||||
queryParams := r.URL.Query()
|
||||
keyMarker := queryParams.Get("key-marker")
|
||||
uploadIdMarker := queryParams.Get("upload-id-marker")
|
||||
prefix := queryParams.Get("prefix")
|
||||
delimiter := queryParams.Get("delimiter")
|
||||
maxUploadsStr := queryParams.Get("max-uploads")
|
||||
maxUploads := 1000
|
||||
if maxUploadsStr != "" {
|
||||
if m, err := strconv.Atoi(maxUploadsStr); err == nil && m > 0 {
|
||||
maxUploads = m
|
||||
}
|
||||
}
|
||||
|
||||
// 检查请求的存储桶
|
||||
requestedBucket, ok := h.bucketManager.GetBucket(bucketName)
|
||||
if !ok {
|
||||
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
|
||||
return
|
||||
}
|
||||
|
||||
// 简化实现:返回空列表
|
||||
var allUploads []Upload
|
||||
var isTruncated bool
|
||||
|
||||
// 如果是虚拟存储桶,从数据库查询上传会话
|
||||
if requestedBucket.IsVirtual() {
|
||||
// 从数据库获取待处理的上传会话
|
||||
sessions, err := h.storage.GetPendingUploadSessions(prefix, keyMarker, uploadIdMarker, maxUploads)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get pending upload sessions: %v", err)
|
||||
// 降级到遍历所有存储桶的方式
|
||||
ctx := context.Background()
|
||||
allBuckets := h.bucketManager.GetAllBuckets()
|
||||
for _, bucket := range allBuckets {
|
||||
if bucket.IsVirtual() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 列出每个真实存储桶的分片上传
|
||||
listResp, err := bucket.Client.ListMultipartUploads(ctx, &s3.ListMultipartUploadsInput{
|
||||
Bucket: aws.String(bucket.Config.Name),
|
||||
KeyMarker: aws.String(keyMarker),
|
||||
UploadIdMarker: aws.String(uploadIdMarker),
|
||||
Prefix: aws.String(prefix),
|
||||
Delimiter: aws.String(delimiter),
|
||||
MaxUploads: aws.Int32(int32(maxUploads)),
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to list multipart uploads for bucket %s: %v", bucket.Config.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将结果添加到列表中
|
||||
for _, upload := range listResp.Uploads {
|
||||
allUploads = append(allUploads, Upload{
|
||||
Key: aws.ToString(upload.Key),
|
||||
UploadID: aws.ToString(upload.UploadId),
|
||||
Initiator: Owner{
|
||||
ID: aws.ToString(upload.Initiator.ID),
|
||||
DisplayName: aws.ToString(upload.Initiator.DisplayName),
|
||||
},
|
||||
Owner: Owner{
|
||||
ID: aws.ToString(upload.Owner.ID),
|
||||
DisplayName: aws.ToString(upload.Owner.DisplayName),
|
||||
},
|
||||
StorageClass: string(upload.StorageClass),
|
||||
Initiated: aws.ToTime(upload.Initiated),
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 成功从数据库获取会话
|
||||
if len(sessions) > maxUploads {
|
||||
sessions = sessions[:maxUploads]
|
||||
isTruncated = true
|
||||
}
|
||||
|
||||
// 转换会话为Upload格式
|
||||
for _, session := range sessions {
|
||||
allUploads = append(allUploads, Upload{
|
||||
Key: session.Key,
|
||||
UploadID: session.UploadID,
|
||||
Initiator: Owner{
|
||||
ID: "s3-balance",
|
||||
DisplayName: "S3 Balance User",
|
||||
},
|
||||
Owner: Owner{
|
||||
ID: "s3-balance",
|
||||
DisplayName: "S3 Balance User",
|
||||
},
|
||||
StorageClass: "STANDARD",
|
||||
Initiated: session.CreatedAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作
|
||||
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
result := ListMultipartUploadsResult{
|
||||
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
|
||||
Bucket: bucketName,
|
||||
KeyMarker: key,
|
||||
MaxUploads: 1000,
|
||||
IsTruncated: false,
|
||||
Uploads: make([]Upload, 0),
|
||||
KeyMarker: keyMarker,
|
||||
UploadIdMarker: uploadIdMarker,
|
||||
MaxUploads: maxUploads,
|
||||
IsTruncated: isTruncated,
|
||||
Uploads: allUploads,
|
||||
}
|
||||
|
||||
// 如果有更多结果,设置下一个标记
|
||||
if isTruncated && len(allUploads) > 0 {
|
||||
lastUpload := allUploads[len(allUploads)-1]
|
||||
result.NextKeyMarker = lastUpload.Key
|
||||
result.NextUploadIdMarker = lastUpload.UploadID
|
||||
}
|
||||
|
||||
h.sendXMLResponse(w, http.StatusOK, result)
|
||||
@@ -941,22 +1063,117 @@ func (h *S3Handler) handleListMultipartParts(w http.ResponseWriter, r *http.Requ
|
||||
key := vars["key"]
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
// 检查bucket是否存在
|
||||
if _, ok := h.bucketManager.GetBucket(bucketName); !ok {
|
||||
// 解析查询参数
|
||||
queryParams := r.URL.Query()
|
||||
partNumberMarkerStr := queryParams.Get("part-number-marker")
|
||||
partNumberMarker := 0
|
||||
if partNumberMarkerStr != "" {
|
||||
if m, err := strconv.Atoi(partNumberMarkerStr); err == nil && m > 0 {
|
||||
partNumberMarker = m
|
||||
}
|
||||
}
|
||||
maxPartsStr := queryParams.Get("max-parts")
|
||||
maxParts := 1000
|
||||
if maxPartsStr != "" {
|
||||
if m, err := strconv.Atoi(maxPartsStr); err == nil && m > 0 {
|
||||
maxParts = m
|
||||
}
|
||||
}
|
||||
|
||||
// 检查请求的存储桶
|
||||
requestedBucket, ok := h.bucketManager.GetBucket(bucketName)
|
||||
if !ok {
|
||||
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
|
||||
return
|
||||
}
|
||||
|
||||
// 简化实现:返回空列表
|
||||
var targetBucket *bucket.BucketInfo
|
||||
|
||||
// 如果是虚拟存储桶,需要通过映射查找真实存储桶
|
||||
if requestedBucket.IsVirtual() {
|
||||
// 获取虚拟存储桶映射
|
||||
mapping, err := h.storage.GetVirtualBucketMapping(bucketName, key)
|
||||
if err != nil {
|
||||
// 如果没有找到映射,尝试查询所有真实存储桶
|
||||
allBuckets := h.bucketManager.GetAllBuckets()
|
||||
for _, bucket := range allBuckets {
|
||||
if bucket.IsVirtual() {
|
||||
continue
|
||||
}
|
||||
// 尝试列出分片,如果成功则说明上传在这个桶中
|
||||
ctx := context.Background()
|
||||
_, err := bucket.Client.ListParts(ctx, &s3.ListPartsInput{
|
||||
Bucket: aws.String(bucket.Config.Name),
|
||||
Key: aws.String(key),
|
||||
UploadId: aws.String(uploadID),
|
||||
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
|
||||
MaxParts: aws.Int32(1), // 只检查是否存在
|
||||
})
|
||||
if err == nil {
|
||||
targetBucket = bucket
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetBucket == nil {
|
||||
h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 获取映射到的真实存储桶
|
||||
targetBucket, ok = h.bucketManager.GetBucket(mapping.RealBucketName)
|
||||
if !ok {
|
||||
h.sendS3Error(w, "InternalError", "Mapped real bucket not found", key)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是虚拟存储桶,拒绝客户端对真实存储桶的直接操作
|
||||
h.sendS3Error(w, "NoSuchBucket", "The specified bucket does not exist", bucketName)
|
||||
return
|
||||
}
|
||||
|
||||
// 列出分片
|
||||
ctx := context.Background()
|
||||
listResp, err := targetBucket.Client.ListParts(ctx, &s3.ListPartsInput{
|
||||
Bucket: aws.String(targetBucket.Config.Name),
|
||||
Key: aws.String(key),
|
||||
UploadId: aws.String(uploadID),
|
||||
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
|
||||
MaxParts: aws.Int32(int32(maxParts)),
|
||||
})
|
||||
if err != nil {
|
||||
h.sendS3Error(w, "NoSuchUpload", "The specified multipart upload does not exist", uploadID)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换分片列表
|
||||
var parts []Part
|
||||
for _, part := range listResp.Parts {
|
||||
parts = append(parts, Part{
|
||||
PartNumber: int(aws.ToInt32(part.PartNumber)),
|
||||
LastModified: aws.ToTime(part.LastModified),
|
||||
ETag: aws.ToString(part.ETag),
|
||||
Size: aws.ToInt64(part.Size),
|
||||
})
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
result := ListPartsResult{
|
||||
Xmlns: "http://s3.amazonaws.com/doc/2006-03-01/",
|
||||
Bucket: bucketName,
|
||||
Bucket: bucketName, // 返回虚拟存储桶名称给客户端
|
||||
Key: key,
|
||||
UploadID: uploadID,
|
||||
PartNumberMarker: 0,
|
||||
MaxParts: 1000,
|
||||
IsTruncated: false,
|
||||
Parts: make([]Part, 0),
|
||||
PartNumberMarker: partNumberMarker,
|
||||
MaxParts: maxParts,
|
||||
IsTruncated: aws.ToBool(listResp.IsTruncated),
|
||||
Parts: parts,
|
||||
}
|
||||
|
||||
// 设置下一个分片标记
|
||||
if listResp.NextPartNumberMarker != nil {
|
||||
if nextMarker, err := strconv.Atoi(aws.ToString(listResp.NextPartNumberMarker)); err == nil {
|
||||
result.NextPartNumberMarker = nextMarker
|
||||
}
|
||||
}
|
||||
|
||||
h.sendXMLResponse(w, http.StatusOK, result)
|
||||
@@ -1002,7 +1219,18 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
|
||||
// 解析请求体以获取分片列表
|
||||
var completeReq CompleteMultipartUpload
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
xml.Unmarshal(body, &completeReq)
|
||||
err := xml.Unmarshal(body, &completeReq)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse CompleteMultipartUpload request body: %v, body: %s", err, string(body))
|
||||
h.sendS3Error(w, "MalformedXML", "The XML you provided was not well-formed", key)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("CompleteMultipartUpload request - Bucket: %s, Key: %s, UploadID: %s, Parts: %d",
|
||||
bucketName, key, uploadID, len(completeReq.Parts))
|
||||
for i, part := range completeReq.Parts {
|
||||
log.Printf(" Part %d: PartNumber=%d, ETag=%s", i+1, part.PartNumber, part.ETag)
|
||||
}
|
||||
|
||||
// 完成分片上传
|
||||
ctx := context.Background()
|
||||
@@ -1014,6 +1242,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("Calling CompleteMultipartUpload on real bucket %s with uploadID %s", targetBucket.Config.Name, uploadID)
|
||||
completeResp, err := targetBucket.Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
|
||||
Bucket: aws.String(targetBucket.Config.Name),
|
||||
Key: aws.String(key),
|
||||
@@ -1023,6 +1252,7 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("CompleteMultipartUpload failed: %v", err)
|
||||
h.sendS3Error(w, "InternalError", "Failed to complete multipart upload", key)
|
||||
return
|
||||
}
|
||||
@@ -1057,6 +1287,12 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
|
||||
targetBucket.UpdateUsedSize(objectSize)
|
||||
}
|
||||
|
||||
// 更新上传会话状态为已完成
|
||||
if err := h.storage.UpdateUploadSession(uploadID, len(completeReq.Parts), "completed"); err != nil {
|
||||
log.Printf("Failed to update upload session status to completed for uploadID %s: %v", uploadID, err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
h.sendXMLResponse(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
@@ -1110,6 +1346,12 @@ func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Re
|
||||
log.Printf("Failed to abort multipart upload for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
// 更新上传会话状态为已中止
|
||||
if err := h.storage.UpdateUploadSession(uploadID, 0, "aborted"); err != nil {
|
||||
log.Printf("Failed to update upload session status to aborted for uploadID %s: %v", uploadID, err)
|
||||
// 不影响主流程
|
||||
}
|
||||
|
||||
// 如果是虚拟存储桶,还需要删除文件级别映射
|
||||
if requestedBucket.IsVirtual() {
|
||||
h.storage.DeleteVirtualBucketFileMapping(bucketName, key)
|
||||
|
||||
@@ -61,7 +61,7 @@ func (VirtualBucketMapping) TableName() string {
|
||||
// UploadSession 上传会话模型(用于跟踪分片上传)
|
||||
type UploadSession struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UploadID string `gorm:"uniqueIndex;size:255;not null" json:"upload_id"`
|
||||
UploadID string `gorm:"uniqueIndex;size:512;not null" json:"upload_id"` // 增加到512字符以支持长uploadID
|
||||
Key string `gorm:"index;size:512;not null" json:"key"`
|
||||
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
|
||||
TotalParts int `gorm:"not null;default:0" json:"total_parts"`
|
||||
|
||||
@@ -301,6 +301,41 @@ func (s *Service) UpdateUploadSession(uploadID string, completedParts int, statu
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPendingUploadSessions 获取正在进行中的上传会话
|
||||
func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uploadIdMarker string, maxUploads int) ([]*UploadSession, error) {
|
||||
query := s.db.Model(&UploadSession{}).Where("status = ?", "pending")
|
||||
|
||||
// 根据前缀过滤
|
||||
if prefix != "" {
|
||||
query = query.Where("`key` LIKE ?", prefix+"%")
|
||||
}
|
||||
|
||||
// 分页标记处理
|
||||
if keyMarker != "" {
|
||||
if uploadIdMarker != "" {
|
||||
// 如果同时指定了key和uploadId标记
|
||||
query = query.Where("(`key` > ? OR (`key` = ? AND upload_id > ?))", keyMarker, keyMarker, uploadIdMarker)
|
||||
} else {
|
||||
query = query.Where("`key` > ?", keyMarker)
|
||||
}
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if maxUploads > 0 {
|
||||
query = query.Limit(maxUploads + 1) // 多查询一个以判断是否截断
|
||||
}
|
||||
|
||||
// 按key和uploadID排序
|
||||
query = query.Order("`key` ASC, upload_id ASC")
|
||||
|
||||
var sessions []*UploadSession
|
||||
if err := query.Find(&sessions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get pending upload sessions: %w", err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// CleanExpiredSessions 清理过期的上传会话
|
||||
func (s *Service) CleanExpiredSessions() error {
|
||||
if err := s.db.Where("expires_at < ? AND status = ?", time.Now(), "pending").
|
||||
|
||||
Reference in New Issue
Block a user