From 3a4effab33ce57257376ffe330667bcc746d48c5 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 15 Feb 2025 15:06:06 +0800 Subject: [PATCH] feat: refactor file processing and storage handling with improved path management --- core/core.go | 58 +++++++++++++++------------------------- core/utils.go | 34 ++++++++++++++++++++--- storage/alist/alist.go | 4 --- storage/local/local.go | 1 - storage/storage.go | 16 +++++++++-- storage/webdav/webdav.go | 6 +---- 6 files changed, 68 insertions(+), 51 deletions(-) diff --git a/core/core.go b/core/core.go index 5adb210..5107ad0 100644 --- a/core/core.go +++ b/core/core.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "path" "path/filepath" "time" @@ -21,13 +22,13 @@ import ( func processPendingTask(task *types.Task) error { logger.L.Debugf("Start processing task: %s", task.String()) - destPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) - absDestPath, err := filepath.Abs(destPath) + cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) + cacheDestPath, err := filepath.Abs(cacheDestPath) if err != nil { - return fmt.Errorf("Failed to get absolute path: %w", err) + return fmt.Errorf("failed to get absolute path: %w", err) } - if err := fileutil.CreateDir(filepath.Dir(absDestPath)); err != nil { - return fmt.Errorf("Failed to create directory: %w", err) + if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) } ctx := task.Ctx.(*ext.Context) @@ -39,32 +40,17 @@ func processPendingTask(task *types.Task) error { if task.StoragePath == "" { task.StoragePath = task.File.FileName } + switch task.Storage { + case types.Local: + task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath) + case types.Webdav: + task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath) + case types.Alist: + task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath) + } - // process photo if task.File.FileSize == 0 { - res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ - Location: task.File.Location, - Offset: 0, - Limit: 1024 * 1024, - }) - if err != nil { - return fmt.Errorf("Failed to get file: %w", err) - } - - result, ok := res.(*tg.UploadFile) - if !ok { - return fmt.Errorf("unexpected type %T", res) - } - - if err := os.WriteFile(destPath, result.Bytes, os.ModePerm); err != nil { - return fmt.Errorf("Failed to write file: %w", err) - } - - defer cleanCacheFile(destPath) - - logger.L.Infof("Downloaded file: %s", destPath) - - return saveFileWithRetry(task, destPath) + return processPhoto(task, cacheDestPath) } barTotalCount := calculateBarTotalCount(task.File.FileSize) @@ -92,29 +78,29 @@ func processPendingTask(task *types.Task) error { 0, task.File.FileSize-1, task.File.FileSize, progressCallback, task.File.FileSize/100) if err != nil { - return fmt.Errorf("Failed to create reader: %w", err) + return fmt.Errorf("failed to create reader: %w", err) } defer readCloser.Close() - dest, err := os.Create(destPath) + dest, err := os.Create(cacheDestPath) if err != nil { - return fmt.Errorf("Failed to create file: %w", err) + return fmt.Errorf("failed to create file: %w", err) } defer dest.Close() task.StartTime = time.Now() if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { - return fmt.Errorf("Failed to download file: %w", err) + return fmt.Errorf("failed to download file: %w", err) } - defer cleanCacheFile(destPath) + defer cleanCacheFile(cacheDestPath) - logger.L.Infof("Downloaded file: %s", destPath) + logger.L.Infof("Downloaded file: %s", cacheDestPath) ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{ Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), ID: task.ReplyMessageID, }) - return saveFileWithRetry(task, destPath) + return saveFileWithRetry(task, cacheDestPath) } func worker(queue *queue.TaskQueue, semaphore chan struct{}) { diff --git a/core/utils.go b/core/utils.go index 70951cd..a3b02df 100644 --- a/core/utils.go +++ b/core/utils.go @@ -5,6 +5,8 @@ import ( "os" "time" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/bot" "github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" @@ -12,11 +14,11 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(task *types.Task, destPath string) error { +func saveFileWithRetry(task *types.Task, localFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { - if err := storage.Save(task.Storage, task.Ctx, destPath, task.StoragePath); err != nil { + if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil { if i == config.Cfg.Retry { - return fmt.Errorf("Failed to save file: %w", err) + return fmt.Errorf("failed to save file: %w", err) } logger.L.Errorf("Failed to save file: %s, retrying...", err) continue @@ -26,6 +28,32 @@ func saveFileWithRetry(task *types.Task, destPath string) error { return nil } +func processPhoto(task *types.Task, cachePath string) error { + res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ + Location: task.File.Location, + Offset: 0, + Limit: 1024 * 1024, + }) + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + result, ok := res.(*tg.UploadFile) + if !ok { + return fmt.Errorf("unexpected type %T", res) + } + + if err := os.WriteFile(cachePath, result.Bytes, os.ModePerm); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + defer cleanCacheFile(cachePath) + + logger.L.Infof("Downloaded file: %s", cachePath) + + return saveFileWithRetry(task, cachePath) +} + func getProgressBar(progress float64, totalCount int) string { bar := "" barSize := 100 / totalCount diff --git a/storage/alist/alist.go b/storage/alist/alist.go index b6cdf29..75bb3e2 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -10,7 +10,6 @@ import ( "net/http" "net/url" "os" - "path" "time" "github.com/krau/SaveAny-Bot/config" @@ -20,7 +19,6 @@ import ( type Alist struct { client *http.Client token string - basePath string baseURL string loginInfo *loginRequest } @@ -105,7 +103,6 @@ func (a *Alist) refreshToken() { } func (a *Alist) Init() { - a.basePath = config.Cfg.Storage.Alist.BasePath a.baseURL = config.Cfg.Storage.Alist.URL a.client = &http.Client{ Timeout: 12 * time.Hour, @@ -128,7 +125,6 @@ func (a *Alist) Init() { } func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { - storagePath = path.Join(a.basePath, storagePath) file, err := os.Open(filePath) if err != nil { return fmt.Errorf("failed to open file: %w", err) diff --git a/storage/local/local.go b/storage/local/local.go index b377ee4..373f3fc 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -21,7 +21,6 @@ func (l *Local) Init() { } func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { - storagePath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath) absPath, err := filepath.Abs(storagePath) if err != nil { return err diff --git a/storage/storage.go b/storage/storage.go index db0b64e..1cfcb95 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,6 +3,8 @@ package storage import ( "context" "errors" + "path" + "path/filepath" "sync" "github.com/duke-git/lancet/v2/slice" @@ -16,7 +18,7 @@ import ( type Storage interface { Init() - Save(cttx context.Context, filePath, storagePath string) error + Save(cttx context.Context, localFilePath, storagePath string) error } var Storages = make(map[types.StorageType]Storage) @@ -47,6 +49,7 @@ func Init() { } func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error { + logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath) if ctx == nil { ctx = context.Background() } @@ -59,7 +62,16 @@ func Save(storageType types.StorageType, ctx context.Context, filePath, storageP wg.Add(1) go func(storage Storage) { defer wg.Done() - if err := storage.Save(ctx, filePath, storagePath); err != nil { + storageDestPath := storagePath + switch storage.(type) { + case *local.Local: + storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath) + case *webdav.Webdav: + storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath) + case *alist.Alist: + storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath) + } + if err := storage.Save(ctx, filePath, storageDestPath); err != nil { errs = append(errs, err) } }(storage) diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index bb29cb5..f0c3a8b 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -4,7 +4,6 @@ import ( "context" "os" "path" - "strings" "time" "github.com/krau/SaveAny-Bot/config" @@ -15,13 +14,11 @@ import ( type Webdav struct{} var ( - Client *gowebdav.Client - basePath string + Client *gowebdav.Client ) func (w *Webdav) Init() { webdavConfig := config.Cfg.Storage.Webdav - basePath = strings.TrimSuffix(webdavConfig.BasePath, "/") Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password) if err := Client.Connect(); err != nil { logger.L.Fatalf("Failed to connect to webdav server: %v", err) @@ -31,7 +28,6 @@ func (w *Webdav) Init() { } func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error { - storagePath = path.Join(basePath, storagePath) if err := Client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil { logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err) return ErrFailedToCreateDirectory