Record size when multipart upload

This commit is contained in:
DullJZ
2025-10-02 19:15:42 +08:00
parent c2de062070
commit 43bc339943
3 changed files with 173 additions and 82 deletions

View File

@@ -76,6 +76,28 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
return
}
// 检查当前已上传大小 + 本次分片大小是否超过bucket剩余空间
currentSize, err := h.storage.GetUploadSessionSize(uploadID)
if err != nil {
log.Printf("Warning: failed to get upload session size for uploadID %s: %v", uploadID, err)
// 继续处理,不阻止上传
currentSize = 0
}
projectedSize := currentSize + contentLength
availableSpace := targetBucket.GetAvailableSpace()
if projectedSize > availableSpace {
// 空间不足,自动中止后端分片上传
log.Printf("Upload would exceed bucket capacity for key %s, aborting multipart upload. Current: %d bytes, Part: %d bytes, Available: %d bytes",
key, currentSize, contentLength, availableSpace)
h.abortMultipartUploadInternal(targetBucket, key, uploadID)
h.sendS3Error(w, "EntityTooLarge",
fmt.Sprintf("Upload would exceed bucket capacity. Current: %d bytes, Part: %d bytes, Available: %d bytes",
currentSize, contentLength, availableSpace), key)
return
}
// 转换partNumber为整数
partNum, err := strconv.Atoi(partNumber)
if err != nil {
@@ -131,7 +153,7 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
w.Header().Set("ETag", etag)
}
// 更新上传会话的分片数
// 更新上传会话的分片数和累积大小
session, err := h.storage.GetUploadSession(uploadID)
if err != nil {
log.Printf("Failed to get upload session for uploadID %s: %v", uploadID, err)
@@ -140,6 +162,10 @@ func (h *S3Handler) handleUploadPart(w http.ResponseWriter, r *http.Request) {
if err := h.storage.UpdateUploadSession(uploadID, session.CompletedParts+1, "pending"); err != nil {
log.Printf("Failed to update upload session for uploadID %s: %v", uploadID, err)
}
// 累加分片大小
if err := h.storage.IncrementUploadSessionSize(uploadID, contentLength); err != nil {
log.Printf("Failed to increment upload session size for uploadID %s: %v", uploadID, err)
}
}
w.WriteHeader(http.StatusOK)
@@ -201,7 +227,7 @@ func (h *S3Handler) handleMultipartUpload(w http.ResponseWriter, r *http.Request
uploadID := *createResp.UploadId
// 记录上传会话到数据库
if err := h.storage.RecordUploadSession(uploadID, key, targetBucket.Config.Name, 0, 0); err != nil {
if err := h.storage.RecordUploadSession(uploadID, key, targetBucket.Config.Name, 0); err != nil {
log.Printf("Failed to record upload session for uploadID %s: %v", uploadID, err)
// 不影响主流程,继续处理
}
@@ -519,6 +545,29 @@ func (h *S3Handler) handleCompleteMultipartUpload(w http.ResponseWriter, r *http
log.Printf(" Part %d: PartNumber=%d, ETag=%s", i+1, part.PartNumber, part.ETag)
}
// 最终检查验证累积大小是否超过bucket可用空间
totalSize, err := h.storage.GetUploadSessionSize(uploadID)
if err != nil {
log.Printf("Warning: failed to get upload session size for uploadID %s: %v", uploadID, err)
// 继续处理,不阻止完成操作
totalSize = 0
}
if totalSize > 0 {
availableSpace := targetBucket.GetAvailableSpace()
if totalSize > availableSpace {
// 空间不足,自动中止后端分片上传
log.Printf("Upload size exceeds bucket capacity for key %s, aborting multipart upload. Total: %d bytes, Available: %d bytes",
key, totalSize, availableSpace)
h.abortMultipartUploadInternal(targetBucket, key, uploadID)
h.sendS3Error(w, "EntityTooLarge",
fmt.Sprintf("Upload size exceeds bucket capacity. Total: %d bytes, Available: %d bytes",
totalSize, availableSpace), key)
return
}
}
// 完成分片上传
ctx := context.Background()
sort.SliceStable(completeReq.Parts, func(i, j int) bool {
@@ -607,6 +656,28 @@ func getAPIError(err error) (smithy.APIError, bool) {
return nil, false
}
// abortMultipartUploadInternal 内部方法向后端S3发送中止分片上传请求
func (h *S3Handler) abortMultipartUploadInternal(targetBucket *bucket.BucketInfo, key, uploadID string) error {
ctx := context.Background()
_, err := targetBucket.Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(targetBucket.Config.Name),
Key: aws.String(key),
UploadId: aws.String(uploadID),
})
if err != nil {
log.Printf("Failed to abort multipart upload for key %s, uploadID %s: %v", key, uploadID, err)
return err
}
// 更新上传会话状态为已中止
if err := h.storage.UpdateUploadSession(uploadID, 0, "aborted"); err != nil {
log.Printf("Failed to update upload session status to aborted for uploadID %s: %v", uploadID, err)
}
log.Printf("Successfully aborted multipart upload for key %s, uploadID %s", key, uploadID)
return nil
}
// handleAbortMultipartUpload 中止分片上传
func (h *S3Handler) handleAbortMultipartUpload(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)

View File

@@ -10,16 +10,16 @@ import (
// Object 对象信息模型
type Object struct {
ID uint `gorm:"primaryKey" json:"id"`
Key string `gorm:"size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
Size int64 `gorm:"not null;default:0" json:"size"`
Metadata JSON `gorm:"type:json" json:"metadata,omitempty"`
ContentType string `gorm:"size:128" json:"content_type,omitempty"`
ETag string `gorm:"size:128" json:"etag,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
ID uint `gorm:"primaryKey" json:"id"`
Key string `gorm:"size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
Size int64 `gorm:"not null;default:0" json:"size"`
Metadata JSON `gorm:"type:json" json:"metadata,omitempty"`
ContentType string `gorm:"size:128" json:"content_type,omitempty"`
ETag string `gorm:"size:128" json:"etag,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// TableName 指定表名
@@ -45,12 +45,12 @@ func (BucketStats) TableName() string {
// VirtualBucketMapping 虚拟存储桶文件级映射模型
type VirtualBucketMapping struct {
ID uint `gorm:"primaryKey" json:"id"`
ID uint `gorm:"primaryKey" json:"id"`
VirtualBucketName string `gorm:"index;size:255;not null" json:"virtual_bucket_name"`
ObjectKey string `gorm:"index;size:512;not null" json:"object_key"`
RealBucketName string `gorm:"index;size:255;not null" json:"real_bucket_name"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
ObjectKey string `gorm:"index;size:512;not null" json:"object_key"`
RealBucketName string `gorm:"index;size:255;not null" json:"real_bucket_name"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
// TableName 指定表名
@@ -64,7 +64,6 @@ type UploadSession struct {
UploadID string `gorm:"uniqueIndex;size:512;not null" json:"upload_id"` // 增加到512字符以支持长uploadID
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255;not null" json:"bucket_name"`
TotalParts int `gorm:"not null;default:0" json:"total_parts"`
CompletedParts int `gorm:"not null;default:0" json:"completed_parts"`
Size int64 `gorm:"not null;default:0" json:"size"`
Status string `gorm:"size:32;not null;default:'pending'" json:"status"` // pending, completed, aborted
@@ -81,17 +80,17 @@ func (UploadSession) TableName() string {
// AccessLog 访问日志模型
type AccessLog struct {
ID uint `gorm:"primaryKey" json:"id"`
Action string `gorm:"index;size:32;not null" json:"action"` // upload, download, delete
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255" json:"bucket_name"`
Size int64 `gorm:"default:0" json:"size"`
ClientIP string `gorm:"size:64" json:"client_ip"`
UserAgent string `gorm:"size:512" json:"user_agent"`
Success bool `gorm:"default:true" json:"success"`
ErrorMsg string `gorm:"type:text" json:"error_msg,omitempty"`
ResponseTime int64 `gorm:"default:0" json:"response_time"` // 响应时间(毫秒)
CreatedAt time.Time `gorm:"index" json:"created_at"`
ID uint `gorm:"primaryKey" json:"id"`
Action string `gorm:"index;size:32;not null" json:"action"` // upload, download, delete
Key string `gorm:"index;size:512;not null" json:"key"`
BucketName string `gorm:"index;size:255" json:"bucket_name"`
Size int64 `gorm:"default:0" json:"size"`
ClientIP string `gorm:"size:64" json:"client_ip"`
UserAgent string `gorm:"size:512" json:"user_agent"`
Success bool `gorm:"default:true" json:"success"`
ErrorMsg string `gorm:"type:text" json:"error_msg,omitempty"`
ResponseTime int64 `gorm:"default:0" json:"response_time"` // 响应时间(毫秒)
CreatedAt time.Time `gorm:"index" json:"created_at"`
}
// TableName 指定表名
@@ -116,7 +115,7 @@ func (j *JSON) Scan(value interface{}) error {
*j = make(map[string]interface{})
return nil
}
var data []byte
switch v := value.(type) {
case []byte:
@@ -126,7 +125,7 @@ func (j *JSON) Scan(value interface{}) error {
default:
data = []byte("{}")
}
return json.Unmarshal(data, j)
}

View File

@@ -29,13 +29,13 @@ func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[
return fmt.Errorf("failed to permanently delete soft-deleted object: %w", err)
}
}
obj := &Object{
Key: key,
BucketName: bucketName,
Size: size,
}
if len(metadata) > 0 {
obj.Metadata = make(JSON)
for k, v := range metadata {
@@ -44,13 +44,13 @@ func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[
} else {
obj.Metadata = make(JSON)
}
// 使用 Upsert更新或插入
result := s.db.Where("`key` = ?", key).Where("`deleted_at` IS NULL").FirstOrCreate(&obj)
if result.Error != nil {
return fmt.Errorf("failed to record object: %w", result.Error)
}
if result.RowsAffected == 0 {
// 对象已存在,更新它
updates := map[string]interface{}{
@@ -63,10 +63,10 @@ func (s *Service) RecordObject(key, bucketName string, size int64, metadata map[
return fmt.Errorf("failed to update object: %w", err)
}
}
// 更新存储桶统计
s.updateBucketStats(bucketName)
return nil
}
@@ -103,17 +103,17 @@ func (s *Service) DeleteObject(key string) error {
}
return fmt.Errorf("failed to find object: %w", err)
}
bucketName := obj.BucketName
// 软删除
if err := s.db.Delete(&obj).Error; err != nil {
return fmt.Errorf("failed to delete object: %w", err)
}
// 更新存储桶统计
s.updateBucketStats(bucketName)
return nil
}
@@ -121,32 +121,32 @@ func (s *Service) DeleteObject(key string) error {
func (s *Service) ListObjects(bucketName, prefix, marker string, maxKeys int) ([]*Object, error) {
var objects []*Object
query := s.db.Model(&Object{})
// 按bucket过滤
if bucketName != "" {
query = query.Where("bucket_name = ?", bucketName)
}
// 前缀过滤
if prefix != "" {
query = query.Where("`key` LIKE ?", prefix+"%")
}
// Marker分页
if marker != "" {
query = query.Where("`key` > ?", marker)
}
// 限制返回数量
if maxKeys > 0 {
query = query.Limit(maxKeys)
}
// 按key字母顺序排序S3标准
if err := query.Order("`key` ASC").Find(&objects).Error; err != nil {
return nil, fmt.Errorf("failed to list objects: %w", err)
}
return objects, nil
}
@@ -203,27 +203,27 @@ func (s *Service) GetBucketObjectCount(bucketName string) (int64, error) {
// updateBucketStats 更新存储桶统计信息
func (s *Service) updateBucketStats(bucketName string) error {
var stats BucketStats
// 计算新的统计数据
var count int64
var totalSize int64
if err := s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Count(&count).Error; err != nil {
return fmt.Errorf("failed to count objects: %w", err)
}
if err := s.db.Model(&Object{}).
Where("bucket_name = ?", bucketName).
Select("COALESCE(SUM(size), 0)").
Scan(&totalSize).Error; err != nil {
return fmt.Errorf("failed to sum object sizes: %w", err)
}
// 尝试查找现有记录
err := s.db.Where("bucket_name = ?", bucketName).First(&stats).Error
if err == gorm.ErrRecordNotFound {
// 记录不存在,创建新记录
stats = BucketStats{
@@ -245,30 +245,29 @@ func (s *Service) updateBucketStats(bucketName string) error {
"total_size": totalSize,
"last_checked_at": time.Now(),
}
if err := s.db.Model(&stats).Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update bucket stats: %w", err)
}
}
return nil
}
// RecordUploadSession 记录上传会话
func (s *Service) RecordUploadSession(uploadID, key, bucketName string, totalParts int, size int64) error {
func (s *Service) RecordUploadSession(uploadID, key, bucketName string, size int64) error {
session := &UploadSession{
UploadID: uploadID,
Key: key,
BucketName: bucketName,
TotalParts: totalParts,
Size: size,
Status: "pending",
}
if err := s.db.Create(session).Error; err != nil {
return fmt.Errorf("failed to record upload session: %w", err)
}
return nil
}
@@ -291,25 +290,47 @@ func (s *Service) UpdateUploadSession(uploadID string, completedParts int, statu
"status": status,
"updated_at": time.Now(),
}
if err := s.db.Model(&UploadSession{}).
Where("upload_id = ?", uploadID).
Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update upload session: %w", err)
}
return nil
}
// IncrementUploadSessionSize 增加上传会话的大小(用于累加分片大小)
func (s *Service) IncrementUploadSessionSize(uploadID string, partSize int64) error {
if err := s.db.Model(&UploadSession{}).
Where("upload_id = ?", uploadID).
UpdateColumn("size", gorm.Expr("size + ?", partSize)).Error; err != nil {
return fmt.Errorf("failed to increment upload session size: %w", err)
}
return nil
}
// GetUploadSessionSize 获取上传会话当前累积的大小
func (s *Service) GetUploadSessionSize(uploadID string) (int64, error) {
var session UploadSession
if err := s.db.Select("size").Where("upload_id = ?", uploadID).First(&session).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return 0, fmt.Errorf("upload session not found: %s", uploadID)
}
return 0, fmt.Errorf("failed to get upload session size: %w", err)
}
return session.Size, nil
}
// GetPendingUploadSessions 获取正在进行中的上传会话
func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uploadIdMarker string, maxUploads int) ([]*UploadSession, error) {
query := s.db.Model(&UploadSession{}).Where("status = ?", "pending")
// 根据前缀过滤
if prefix != "" {
query = query.Where("`key` LIKE ?", prefix+"%")
}
// 分页标记处理
if keyMarker != "" {
if uploadIdMarker != "" {
@@ -319,20 +340,20 @@ func (s *Service) GetPendingUploadSessions(prefix string, keyMarker string, uplo
query = query.Where("`key` > ?", keyMarker)
}
}
// 限制返回数量
if maxUploads > 0 {
query = query.Limit(maxUploads + 1) // 多查询一个以判断是否截断
}
// 按key和uploadID排序
query = query.Order("`key` ASC, upload_id ASC")
var sessions []*UploadSession
if err := query.Find(&sessions).Error; err != nil {
return nil, fmt.Errorf("failed to get pending upload sessions: %w", err)
}
return sessions, nil
}
@@ -358,18 +379,18 @@ func (s *Service) RecordAccessLog(action, key, bucketName, clientIP, userAgent s
ErrorMsg: errorMsg,
ResponseTime: responseTime,
}
if err := s.db.Create(log).Error; err != nil {
return fmt.Errorf("failed to record access log: %w", err)
}
return nil
}
// GetAccessLogs 获取访问日志
func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) {
query := s.db.Model(&AccessLog{})
if filter != nil {
if filter.Action != "" {
query = query.Where("action = ?", filter.Action)
@@ -399,12 +420,12 @@ func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) {
query = query.Offset(filter.Offset)
}
}
var logs []*AccessLog
if err := query.Order("created_at DESC").Find(&logs).Error; err != nil {
return nil, fmt.Errorf("failed to get access logs: %w", err)
}
return logs, nil
}
@@ -412,14 +433,14 @@ func (s *Service) GetAccessLogs(filter *AccessLogFilter) ([]*AccessLog, error) {
func (s *Service) CreateVirtualBucketMapping(virtualBucketName, objectKey, realBucketName string) error {
mapping := &VirtualBucketMapping{
VirtualBucketName: virtualBucketName,
ObjectKey: objectKey,
RealBucketName: realBucketName,
ObjectKey: objectKey,
RealBucketName: realBucketName,
}
if err := s.db.Create(mapping).Error; err != nil {
return fmt.Errorf("failed to create virtual bucket mapping: %w", err)
}
return nil
}
@@ -457,15 +478,15 @@ func (s *Service) GetVirtualBucketMappingsForBucket(virtualBucketName string) ([
func (s *Service) UpdateVirtualBucketMapping(virtualBucketName, objectKey, realBucketName string) error {
updates := map[string]interface{}{
"real_bucket_name": realBucketName,
"updated_at": time.Now(),
"updated_at": time.Now(),
}
if err := s.db.Model(&VirtualBucketMapping{}).
Where("virtual_bucket_name = ? AND object_key = ?", virtualBucketName, objectKey).
Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update virtual bucket mapping: %w", err)
}
return nil
}
@@ -495,23 +516,23 @@ func (s *Service) GetVirtualBucketObjects(virtualBucketName string) ([]*Object,
if err != nil {
return nil, err
}
if len(mappings) == 0 {
return []*Object{}, nil
}
// 收集所有对象键
objectKeys := make([]string, 0, len(mappings))
for _, mapping := range mappings {
objectKeys = append(objectKeys, mapping.ObjectKey)
}
// 从对象表中查询这些对象
var objects []*Object
if err := s.db.Where("`key` IN ?", objectKeys).Find(&objects).Error; err != nil {
return nil, fmt.Errorf("failed to get objects for virtual bucket: %w", err)
}
return objects, nil
}