From 2d2becccf64574070485906cb0f3f88948039096 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Fri, 21 Mar 2025 23:05:09 +0800 Subject: [PATCH] refactor: update storage interface to use io.Reader for Save method and remove stream implementations --- core/download.go | 73 ++++++++++++++------------- core/utils.go | 21 ++++---- storage/alist/alist.go | 104 ++------------------------------------- storage/local/local.go | 30 +++++------ storage/minio/client.go | 7 +-- storage/minio/stream.go | 92 ---------------------------------- storage/storage.go | 7 +-- storage/webdav/stream.go | 58 ---------------------- storage/webdav/webdav.go | 23 +++------ 9 files changed, 77 insertions(+), 338 deletions(-) delete mode 100644 storage/minio/stream.go delete mode 100644 storage/webdav/stream.go diff --git a/core/download.go b/core/download.go index 2271024..56802b4 100644 --- a/core/download.go +++ b/core/download.go @@ -3,6 +3,7 @@ package core import ( "context" "fmt" + "io" "path/filepath" "time" @@ -14,6 +15,7 @@ import ( "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" + "golang.org/x/sync/errgroup" ) func processPendingTask(task *types.Task) error { @@ -40,10 +42,6 @@ func processPendingTask(task *types.Task) error { } task.StoragePath = taskStorage.JoinStoragePath(*task) - if task.File.FileSize == 0 { - return processPhoto(task, taskStorage, cacheDestPath) - } - ctx, ok := task.Ctx.(*ext.Context) if !ok { return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) @@ -52,38 +50,47 @@ func processPendingTask(task *types.Task) error { cancelCtx, cancel := context.WithCancel(ctx) task.Cancel = cancel - downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) + if task.File.FileSize == 0 { + return processPhoto(task, taskStorage) + } + + downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) - taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage) if config.Cfg.Stream { - if !isStreamStorage { - common.Log.Warnf("存储 %s 不支持流式上传", taskStorage.Name()) - } else { - text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - ReplyMarkup: getCancelTaskMarkup(task), - }) - uploadStream, err := taskStreamStorage.NewUploadStream(cancelCtx, task.StoragePath) - if err != nil { - return fmt.Errorf("创建上传流失败: %w", err) + + text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) + ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), + }) + + pr, pw := io.Pipe() + defer pr.Close() + + task.StartTime = time.Now() + progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) + + progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback) + + eg, uploadCtx := errgroup.WithContext(cancelCtx) + + eg.Go(func() error { + return taskStorage.Save(uploadCtx, pr, task.StoragePath) + }) + eg.Go(func() error { + _, err := downloadBuilder.Stream(uploadCtx, progressStream) + if closeErr := pw.CloseWithError(err); closeErr != nil { + common.Log.Errorf("Failed to close pipe writer: %v", closeErr) } - defer uploadStream.Close() - - task.StartTime = time.Now() - progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) - - progressStream := NewProgressStream(uploadStream, task.File.FileSize, progressCallback) - - _, err = downloadBuider.Stream(cancelCtx, progressStream) - if err != nil { - return fmt.Errorf("下载文件失败: %w", err) - } - common.Log.Infof("Uploaded file: %s", task.StoragePath) - return nil + return err + }) + if err := eg.Wait(); err != nil { + return err } + + return nil } text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) @@ -101,7 +108,7 @@ func processPendingTask(task *types.Task) error { } defer dest.Close() task.StartTime = time.Now() - _, err = downloadBuider.Parallel(cancelCtx, dest) + _, err = downloadBuilder.Parallel(cancelCtx, dest) if err != nil { return fmt.Errorf("下载文件失败: %w", err) } diff --git a/core/utils.go b/core/utils.go index 8a8130b..ad882d7 100644 --- a/core/utils.go +++ b/core/utils.go @@ -1,6 +1,7 @@ package core import ( + "bytes" "context" "fmt" "io" @@ -20,12 +21,16 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error { +func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, cacheFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { if err := ctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) } - if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil { + file, err := os.Open(cacheFilePath) + if err != nil { + return fmt.Errorf("failed to open cache file: %w", err) + } + if err := taskStorage.Save(ctx, file, task.StoragePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } @@ -42,7 +47,7 @@ func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storag return nil } -func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath string) error { +func processPhoto(task *types.Task, taskStorage storage.Storage) error { res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ Location: task.File.Location, Offset: 0, @@ -57,15 +62,9 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin return fmt.Errorf("unexpected type %T", res) } - if err := os.WriteFile(cachePath, result.Bytes, os.ModePerm); err != nil { - return fmt.Errorf("failed to write file: %w", err) - } + common.Log.Infof("Downloaded photo: %s", task.FileName()) - defer cleanCacheFile(cachePath) - - common.Log.Infof("Downloaded file: %s", cachePath) - - return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath) + return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath) } func cleanCacheFile(destPath string) { diff --git a/storage/alist/alist.go b/storage/alist/alist.go index cb84bb7..a2ca632 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -7,9 +7,7 @@ import ( "io" "net/http" "net/url" - "os" "path" - "sync" "time" "github.com/krau/SaveAny-Bot/common" @@ -98,28 +96,16 @@ func (a *Alist) Name() string { return a.config.Name } -func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { - common.Log.Infof("Saving file %s to %s", filePath, storagePath) - file, err := os.Open(filePath) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - defer file.Close() +func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error { + common.Log.Infof("Saving file to %s", storagePath) - filestat, err := file.Stat() - if err != nil { - return fmt.Errorf("failed to get file stats: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", file) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Authorization", a.token) req.Header.Set("File-Path", url.PathEscape(storagePath)) - // req.Header.Set("As-Task", "true") req.Header.Set("Content-Type", "application/octet-stream") - req.ContentLength = filestat.Size() resp, err := a.client.Do(req) if err != nil { @@ -151,87 +137,3 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { func (a *Alist) JoinStoragePath(task types.Task) string { return path.Join(a.config.BasePath, task.StoragePath) } - -type uploadStream struct { - ctx context.Context - client *http.Client - token string - storagePath string - baseURL string - pr *io.PipeReader - pw *io.PipeWriter - errChan chan error - once sync.Once -} - -func (us *uploadStream) Write(p []byte) (int, error) { - return us.pw.Write(p) -} - -func (us *uploadStream) Close() error { - var uploadErr error - us.once.Do(func() { - if err := us.pw.Close(); err != nil { - uploadErr = fmt.Errorf("failed to close pipe writer: %w", err) - return - } - - if err := <-us.errChan; err != nil { - uploadErr = err - } - }) - return uploadErr -} - -func (a *Alist) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) { - if a.token == "" { - if err := a.getToken(); err != nil { - return nil, fmt.Errorf("not logged in to Alist: %w", err) - } - } - - pr, pw := io.Pipe() - - us := &uploadStream{ - ctx: ctx, - client: a.client, - token: a.token, - storagePath: storagePath, - baseURL: a.baseURL, - pr: pr, - pw: pw, - errChan: make(chan error, 1), - } - - go func() { - defer close(us.errChan) - - req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", pr) - if err != nil { - us.errChan <- fmt.Errorf("failed to create request: %w", err) - return - } - - req.Header.Set("Authorization", a.token) - req.Header.Set("File-Path", url.PathEscape(storagePath)) - // req.Header.Set("As-Task", "true") - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := a.client.Do(req) - if err != nil { - us.errChan <- fmt.Errorf("failed to send request: %w", err) - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - us.errChan <- fmt.Errorf("failed to upload file, status code: %d, response: %s", resp.StatusCode, string(body)) - return - } - - us.errChan <- nil - }() - - return us, nil -} diff --git a/storage/local/local.go b/storage/local/local.go index 721ef20..5a1c11a 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -41,8 +41,13 @@ func (l *Local) Name() string { return l.config.Name } -func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { - common.Log.Infof("Saving file %s to %s", filePath, storagePath) +func (l *Local) JoinStoragePath(task types.Task) string { + return filepath.Join(l.config.BasePath, task.StoragePath) +} + +func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error { + common.Log.Infof("Saving file to %s", storagePath) + absPath, err := filepath.Abs(storagePath) if err != nil { return err @@ -50,24 +55,11 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil { return err } - return fileutil.CopyFile(filePath, storagePath) -} - -func (l *Local) JoinStoragePath(task types.Task) string { - return filepath.Join(l.config.BasePath, task.StoragePath) -} - -func (l *Local) NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) { - absPath, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil { - return nil, err - } file, err := os.Create(absPath) if err != nil { - return nil, err + return err } - return file, nil + defer file.Close() + _, err = io.Copy(file, r) + return err } diff --git a/storage/minio/client.go b/storage/minio/client.go index 18e2c40..bfd0e54 100644 --- a/storage/minio/client.go +++ b/storage/minio/client.go @@ -3,6 +3,7 @@ package minio import ( "context" "fmt" + "io" "path" "github.com/krau/SaveAny-Bot/common" @@ -59,10 +60,10 @@ func (m *Minio) JoinStoragePath(task types.Task) string { return path.Join(m.config.BasePath, task.StoragePath) } -func (m *Minio) Save(ctx context.Context, localFilePath, storagePath string) error { - common.Log.Infof("Saving file %s to %s", localFilePath, storagePath) +func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error { + common.Log.Infof("Saving file from reader to %s", storagePath) - _, err := m.client.FPutObject(ctx, m.config.BucketName, storagePath, localFilePath, minio.PutObjectOptions{}) + _, err := m.client.PutObject(ctx, m.config.BucketName, storagePath, r, -1, minio.PutObjectOptions{}) if err != nil { return fmt.Errorf("failed to upload file to minio: %w", err) } diff --git a/storage/minio/stream.go b/storage/minio/stream.go deleted file mode 100644 index 9aee05c..0000000 --- a/storage/minio/stream.go +++ /dev/null @@ -1,92 +0,0 @@ -package minio - -import ( - "context" - "fmt" - "io" - - "github.com/krau/SaveAny-Bot/common" - "github.com/minio/minio-go/v7" -) - -type MinioWriter struct { - pipeWriter *io.PipeWriter - done chan error - path string - ctx context.Context - closed bool -} - -func (w *MinioWriter) Write(p []byte) (n int, err error) { - select { - case <-w.ctx.Done(): - return 0, w.ctx.Err() - default: - return w.pipeWriter.Write(p) - } -} - -func (w *MinioWriter) Close() error { - if w.closed { - return nil - } - w.closed = true - - if err := w.pipeWriter.Close(); err != nil { - return fmt.Errorf("failed to close pipe writer: %w", err) - } - - select { - case err := <-w.done: - if err != nil { - return fmt.Errorf("upload failed: %w", err) - } - return nil - case <-w.ctx.Done(): - return fmt.Errorf("upload cancelled: %w", w.ctx.Err()) - } -} - -func (m *Minio) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) { - common.Log.Infof("Creating upload stream for %s", storagePath) - - uploadCtx, cancel := context.WithCancel(ctx) - pipeReader, pipeWriter := io.Pipe() - done := make(chan error, 1) - - go func() { - defer func() { - if r := recover(); r != nil { - done <- fmt.Errorf("panic during upload: %v", r) - } - pipeReader.Close() - cancel() - }() - - info, err := m.client.PutObject( - uploadCtx, - m.config.BucketName, - storagePath, - pipeReader, - -1, - minio.PutObjectOptions{}, - ) - - if err != nil { - common.Log.Errorf("Failed to upload to %s: %v", storagePath, err) - done <- err - return - } - - common.Log.Infof("uploaded %d bytes to %s", info.Size, storagePath) - done <- nil - }() - - return &MinioWriter{ - pipeWriter: pipeWriter, - done: done, - path: storagePath, - ctx: uploadCtx, - closed: false, - }, nil -} diff --git a/storage/storage.go b/storage/storage.go index 3b304b5..80b1b91 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -20,12 +20,7 @@ type Storage interface { Type() types.StorageType Name() string JoinStoragePath(task types.Task) string - Save(cttx context.Context, localFilePath, storagePath string) error -} - -type StreamStorage interface { - Storage - NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) + Save(ctx context.Context, reader io.Reader, storagePath string) error } var Storages = make(map[string]Storage) diff --git a/storage/webdav/stream.go b/storage/webdav/stream.go deleted file mode 100644 index e8bbb0a..0000000 --- a/storage/webdav/stream.go +++ /dev/null @@ -1,58 +0,0 @@ -package webdav - -import ( - "context" - "fmt" - "io" - "path" - - "github.com/krau/SaveAny-Bot/common" -) - -type WebdavWriter struct { - pipeWriter *io.PipeWriter - done chan error - path string -} - -func (w *WebdavWriter) Write(p []byte) (n int, err error) { - return w.pipeWriter.Write(p) -} - -func (w *WebdavWriter) Close() error { - if err := w.pipeWriter.Close(); err != nil { - return err - } - if err := <-w.done; err != nil { - return fmt.Errorf("upload failed: %w", err) - } - - return nil -} - -func (w *Webdav) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) { - if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil { - common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err) - return nil, ErrFailedToCreateDirectory - } - pipeReader, pipeWriter := io.Pipe() - done := make(chan error, 1) - go func() { - defer func() { - if err := recover(); err != nil { - done <- fmt.Errorf("panic during upload: %v", err) - } - }() - - err := w.client.WriteFile(ctx, storagePath, pipeReader) - - pipeReader.Close() - done <- err - }() - - return &WebdavWriter{ - pipeWriter: pipeWriter, - done: done, - path: storagePath, - }, nil -} diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index d89e5f4..d0e5ff2 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -3,8 +3,8 @@ package webdav import ( "context" "fmt" + "io" "net/http" - "os" "path" "time" @@ -41,26 +41,19 @@ func (w *Webdav) Name() string { return w.config.Name } -func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error { - common.Log.Infof("Saving file %s to %s", filePath, storagePath) +func (w *Webdav) JoinStoragePath(task types.Task) string { + return path.Join(w.config.BasePath, task.StoragePath) +} + +func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error { + common.Log.Infof("Saving file to %s", storagePath) if err := w.client.MkDir(ctx, path.Dir(storagePath)); err != nil { common.Log.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err) return ErrFailedToCreateDirectory } - file, err := os.Open(filePath) - if err != nil { - common.Log.Errorf("Failed to open file %s: %v", filePath, err) - return err - } - defer file.Close() - - if err := w.client.WriteFile(ctx, storagePath, file); err != nil { + if err := w.client.WriteFile(ctx, storagePath, r); err != nil { common.Log.Errorf("Failed to write file %s: %v", storagePath, err) return ErrFailedToWriteFile } return nil } - -func (w *Webdav) JoinStoragePath(task types.Task) string { - return path.Join(w.config.BasePath, task.StoragePath) -}