From 8e2dd3715595e247bb24924bc44de15bac96021e Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Fri, 28 Feb 2025 11:09:24 +0800 Subject: [PATCH] feat: add stream upload support and related configurations --- config/viper.go | 1 + core/download.go | 51 ++++++++++++++++++++++-- storage/alist/alist.go | 86 ++++++++++++++++++++++++++++++++++++++++ storage/local/local.go | 16 ++++++++ storage/storage.go | 6 +++ storage/webdav/stream.go | 50 +++++++++++++++++++++++ 6 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 storage/webdav/stream.go diff --git a/config/viper.go b/config/viper.go index 7ec3c0f..61fad69 100644 --- a/config/viper.go +++ b/config/viper.go @@ -13,6 +13,7 @@ type Config struct { Retry int `toml:"retry" mapstructure:"retry"` NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"` Threads int `toml:"threads" mapstructure:"threads" json:"threads"` + Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` Users []userConfig `toml:"users" mapstructure:"users" json:"users"` diff --git a/core/download.go b/core/download.go index cce4308..d849ae4 100644 --- a/core/download.go +++ b/core/download.go @@ -8,6 +8,8 @@ import ( "github.com/celestix/gotgproto/ext" "github.com/duke-git/lancet/v2/fileutil" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/bot" "github.com/krau/SaveAny-Bot/config" @@ -52,6 +54,50 @@ func processPendingTask(task *types.Task) error { cancelCtx, cancel := context.WithCancel(ctx) task.Cancel = cancel + downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) + + // TODO: show progress for stream storage + taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage) + if config.Cfg.Stream { + if !isStreamStorage { + logger.L.Warnf("存储 %s 不支持流式上传", taskStorage.Name()) + } else { + entityBuilder := entity.Builder{} + text := fmt.Sprintf("正在处理下载任务 (流式)\n文件名: %s\n保存路径: %s", + task.FileName(), + fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath), + ) + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("正在处理下载任务 (流式)\n文件名: "), + styling.Code(task.FileName()), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)), + ); err != nil { + logger.L.Errorf("Failed to build entities: %s", err) + } else { + text, entities = entityBuilder.Complete() + } + ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), + }) + uploadStream, err := taskStreamStorage.NewUploadStream(cancelCtx, task.StoragePath) + if err != nil { + return fmt.Errorf("创建上传流失败: %w", err) + } + defer uploadStream.Close() + _, err = downloadBuider.Stream(cancelCtx, uploadStream) + if err != nil { + return fmt.Errorf("下载文件失败: %w", err) + } + logger.L.Infof("Uploaded file: %s", task.StoragePath) + return nil + } + } + text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ Message: text, @@ -59,20 +105,18 @@ func processPendingTask(task *types.Task) error { ID: task.ReplyMessageID, ReplyMarkup: getCancelTaskMarkup(task), }) - progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) + progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback) if err != nil { return fmt.Errorf("创建文件失败: %w", err) } defer dest.Close() task.StartTime = time.Now() - downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) _, err = downloadBuider.Parallel(cancelCtx, dest) if err != nil { return fmt.Errorf("下载文件失败: %w", err) } - defer cleanCacheFile(cacheDestPath) fixTaskFileExt(task, cacheDestPath) @@ -84,4 +128,5 @@ func processPendingTask(task *types.Task) error { }) return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath) + } diff --git a/storage/alist/alist.go b/storage/alist/alist.go index bd253ea..d15c7c5 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "path" + "sync" "time" "github.com/krau/SaveAny-Bot/config" @@ -150,3 +151,88 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { func (a *Alist) JoinStoragePath(task types.Task) string { return path.Join(a.config.BasePath, task.StoragePath) } + +type uploadStream struct { + ctx context.Context + client *http.Client + token string + storagePath string + baseURL string + pr *io.PipeReader + pw *io.PipeWriter + errChan chan error + once sync.Once +} + +func (us *uploadStream) Write(p []byte) (int, error) { + return us.pw.Write(p) +} + +func (us *uploadStream) Close() error { + var uploadErr error + us.once.Do(func() { + if err := us.pw.Close(); err != nil { + uploadErr = fmt.Errorf("failed to close pipe writer: %w", err) + return + } + + if err := <-us.errChan; err != nil { + uploadErr = err + } + }) + return uploadErr +} + +func (a *Alist) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) { + if a.token == "" { + if err := a.getToken(); err != nil { + return nil, fmt.Errorf("not logged in to Alist: %w", err) + } + } + + pr, pw := io.Pipe() + + // 创建上传流对象 + us := &uploadStream{ + ctx: ctx, + client: a.client, + token: a.token, + storagePath: storagePath, + baseURL: a.baseURL, + pr: pr, + pw: pw, + errChan: make(chan error, 1), + } + + go func() { + defer close(us.errChan) + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", pr) + if err != nil { + us.errChan <- fmt.Errorf("failed to create request: %w", err) + return + } + + req.Header.Set("Authorization", a.token) + req.Header.Set("File-Path", url.PathEscape(storagePath)) + req.Header.Set("As-Task", "true") + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := a.client.Do(req) + if err != nil { + us.errChan <- fmt.Errorf("failed to send request: %w", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + us.errChan <- fmt.Errorf("failed to upload file, status code: %d, response: %s", resp.StatusCode, string(body)) + return + } + + us.errChan <- nil + }() + + return us, nil +} diff --git a/storage/local/local.go b/storage/local/local.go index e550d42..cc1e48e 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -3,6 +3,7 @@ package local import ( "context" "fmt" + "io" "os" "path/filepath" @@ -55,3 +56,18 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { func (l *Local) JoinStoragePath(task types.Task) string { return filepath.Join(l.config.BasePath, task.StoragePath) } + +func (l *Local) NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) { + absPath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil { + return nil, err + } + file, err := os.Create(absPath) + if err != nil { + return nil, err + } + return file, nil +} diff --git a/storage/storage.go b/storage/storage.go index 3fb7a07..ee6cd2b 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,6 +3,7 @@ package storage import ( "context" "fmt" + "io" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" @@ -20,6 +21,11 @@ type Storage interface { Save(cttx context.Context, localFilePath, storagePath string) error } +type StreamStorage interface { + Storage + NewUploadStream(ctx context.Context, path string) (io.WriteCloser, error) +} + var Storages = make(map[string]Storage) var UserStorages = make(map[int64][]Storage) diff --git a/storage/webdav/stream.go b/storage/webdav/stream.go new file mode 100644 index 0000000..d490c3a --- /dev/null +++ b/storage/webdav/stream.go @@ -0,0 +1,50 @@ +package webdav + +// TODO: gowebdav's WriteStream impl cause high memory usage, need to implement our own WriteStream +// type WebdavWriter struct { +// pipeWriter *io.PipeWriter +// done chan error +// path string +// } + +// func (w *WebdavWriter) Write(p []byte) (n int, err error) { +// return w.pipeWriter.Write(p) +// } + +// func (w *WebdavWriter) Close() error { +// if err := w.pipeWriter.Close(); err != nil { +// return err +// } +// if err := <-w.done; err != nil { +// return fmt.Errorf("upload failed: %w", err) +// } + +// return nil +// } + +// func (w *Webdav) NewUploadStream(ctx context.Context, storagePath string) (io.WriteCloser, error) { +// if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil { +// logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err) +// return nil, ErrFailedToCreateDirectory +// } +// pipeReader, pipeWriter := io.Pipe() +// done := make(chan error, 1) +// go func() { +// defer func() { +// if err := recover(); err != nil { +// done <- fmt.Errorf("panic during upload: %v", err) +// } +// }() + +// err := w.client.WriteStream(storagePath, pipeReader, os.ModePerm) + +// pipeReader.Close() +// done <- err +// }() + +// return &WebdavWriter{ +// pipeWriter: pipeWriter, +// done: done, +// path: storagePath, +// }, nil +// }