feat: refactor file processing and storage handling with improved path management

This commit is contained in:
krau
2025-02-15 15:06:06 +08:00
parent 7692286d78
commit 3a4effab33
6 changed files with 68 additions and 51 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"time"
@@ -21,13 +22,13 @@ import (
func processPendingTask(task *types.Task) error {
logger.L.Debugf("Start processing task: %s", task.String())
destPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
absDestPath, err := filepath.Abs(destPath)
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
cacheDestPath, err := filepath.Abs(cacheDestPath)
if err != nil {
return fmt.Errorf("Failed to get absolute path: %w", err)
return fmt.Errorf("failed to get absolute path: %w", err)
}
if err := fileutil.CreateDir(filepath.Dir(absDestPath)); err != nil {
return fmt.Errorf("Failed to create directory: %w", err)
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
ctx := task.Ctx.(*ext.Context)
@@ -39,32 +40,17 @@ func processPendingTask(task *types.Task) error {
if task.StoragePath == "" {
task.StoragePath = task.File.FileName
}
switch task.Storage {
case types.Local:
task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath)
case types.Webdav:
task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath)
case types.Alist:
task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath)
}
// process photo
if task.File.FileSize == 0 {
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
Location: task.File.Location,
Offset: 0,
Limit: 1024 * 1024,
})
if err != nil {
return fmt.Errorf("Failed to get file: %w", err)
}
result, ok := res.(*tg.UploadFile)
if !ok {
return fmt.Errorf("unexpected type %T", res)
}
if err := os.WriteFile(destPath, result.Bytes, os.ModePerm); err != nil {
return fmt.Errorf("Failed to write file: %w", err)
}
defer cleanCacheFile(destPath)
logger.L.Infof("Downloaded file: %s", destPath)
return saveFileWithRetry(task, destPath)
return processPhoto(task, cacheDestPath)
}
barTotalCount := calculateBarTotalCount(task.File.FileSize)
@@ -92,29 +78,29 @@ func processPendingTask(task *types.Task) error {
0, task.File.FileSize-1, task.File.FileSize,
progressCallback, task.File.FileSize/100)
if err != nil {
return fmt.Errorf("Failed to create reader: %w", err)
return fmt.Errorf("failed to create reader: %w", err)
}
defer readCloser.Close()
dest, err := os.Create(destPath)
dest, err := os.Create(cacheDestPath)
if err != nil {
return fmt.Errorf("Failed to create file: %w", err)
return fmt.Errorf("failed to create file: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
return fmt.Errorf("Failed to download file: %w", err)
return fmt.Errorf("failed to download file: %w", err)
}
defer cleanCacheFile(destPath)
defer cleanCacheFile(cacheDestPath)
logger.L.Infof("Downloaded file: %s", destPath)
logger.L.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
return saveFileWithRetry(task, destPath)
return saveFileWithRetry(task, cacheDestPath)
}
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {

View File

@@ -5,6 +5,8 @@ import (
"os"
"time"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
@@ -12,11 +14,11 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(task *types.Task, destPath string) error {
func saveFileWithRetry(task *types.Task, localFilePath string) error {
for i := 0; i <= config.Cfg.Retry; i++ {
if err := storage.Save(task.Storage, task.Ctx, destPath, task.StoragePath); err != nil {
if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("Failed to save file: %w", err)
return fmt.Errorf("failed to save file: %w", err)
}
logger.L.Errorf("Failed to save file: %s, retrying...", err)
continue
@@ -26,6 +28,32 @@ func saveFileWithRetry(task *types.Task, destPath string) error {
return nil
}
func processPhoto(task *types.Task, cachePath string) error {
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
Location: task.File.Location,
Offset: 0,
Limit: 1024 * 1024,
})
if err != nil {
return fmt.Errorf("failed to get file: %w", err)
}
result, ok := res.(*tg.UploadFile)
if !ok {
return fmt.Errorf("unexpected type %T", res)
}
if err := os.WriteFile(cachePath, result.Bytes, os.ModePerm); err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
defer cleanCacheFile(cachePath)
logger.L.Infof("Downloaded file: %s", cachePath)
return saveFileWithRetry(task, cachePath)
}
func getProgressBar(progress float64, totalCount int) string {
bar := ""
barSize := 100 / totalCount

View File

@@ -10,7 +10,6 @@ import (
"net/http"
"net/url"
"os"
"path"
"time"
"github.com/krau/SaveAny-Bot/config"
@@ -20,7 +19,6 @@ import (
type Alist struct {
client *http.Client
token string
basePath string
baseURL string
loginInfo *loginRequest
}
@@ -105,7 +103,6 @@ func (a *Alist) refreshToken() {
}
func (a *Alist) Init() {
a.basePath = config.Cfg.Storage.Alist.BasePath
a.baseURL = config.Cfg.Storage.Alist.URL
a.client = &http.Client{
Timeout: 12 * time.Hour,
@@ -128,7 +125,6 @@ func (a *Alist) Init() {
}
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
storagePath = path.Join(a.basePath, storagePath)
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)

View File

@@ -21,7 +21,6 @@ func (l *Local) Init() {
}
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
storagePath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath)
absPath, err := filepath.Abs(storagePath)
if err != nil {
return err

View File

@@ -3,6 +3,8 @@ package storage
import (
"context"
"errors"
"path"
"path/filepath"
"sync"
"github.com/duke-git/lancet/v2/slice"
@@ -16,7 +18,7 @@ import (
type Storage interface {
Init()
Save(cttx context.Context, filePath, storagePath string) error
Save(cttx context.Context, localFilePath, storagePath string) error
}
var Storages = make(map[types.StorageType]Storage)
@@ -47,6 +49,7 @@ func Init() {
}
func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error {
logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath)
if ctx == nil {
ctx = context.Background()
}
@@ -59,7 +62,16 @@ func Save(storageType types.StorageType, ctx context.Context, filePath, storageP
wg.Add(1)
go func(storage Storage) {
defer wg.Done()
if err := storage.Save(ctx, filePath, storagePath); err != nil {
storageDestPath := storagePath
switch storage.(type) {
case *local.Local:
storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath)
case *webdav.Webdav:
storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath)
case *alist.Alist:
storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath)
}
if err := storage.Save(ctx, filePath, storageDestPath); err != nil {
errs = append(errs, err)
}
}(storage)

View File

@@ -4,7 +4,6 @@ import (
"context"
"os"
"path"
"strings"
"time"
"github.com/krau/SaveAny-Bot/config"
@@ -15,13 +14,11 @@ import (
type Webdav struct{}
var (
Client *gowebdav.Client
basePath string
Client *gowebdav.Client
)
func (w *Webdav) Init() {
webdavConfig := config.Cfg.Storage.Webdav
basePath = strings.TrimSuffix(webdavConfig.BasePath, "/")
Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := Client.Connect(); err != nil {
logger.L.Fatalf("Failed to connect to webdav server: %v", err)
@@ -31,7 +28,6 @@ func (w *Webdav) Init() {
}
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
storagePath = path.Join(basePath, storagePath)
if err := Client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return ErrFailedToCreateDirectory