From 8975589c43146fc0deb9e359381c0e3aeeb61bc8 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:16:45 +0800 Subject: [PATCH] refactor: file download process and enhance progress tracking --- core/core.go | 100 ----------------------------------------------- core/download.go | 89 +++++++++++++++++++++++++++++++++++++++++ core/reader.go | 66 +++++++++++++++---------------- core/utils.go | 31 +++++++++++++++ 4 files changed, 153 insertions(+), 133 deletions(-) create mode 100644 core/download.go diff --git a/core/core.go b/core/core.go index c1434cb..c91f043 100644 --- a/core/core.go +++ b/core/core.go @@ -4,115 +4,15 @@ import ( "context" "errors" "fmt" - "io" - "os" - "path" - "path/filepath" - "time" - - "github.com/gabriel-vasile/mimetype" "github.com/celestix/gotgproto/ext" - "github.com/duke-git/lancet/v2/fileutil" "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/bot" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" ) -func processPendingTask(task *types.Task) error { - logger.L.Debugf("Start processing task: %s", task.String()) - if task.FileName() == "" { - task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash()) - } - cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) - cacheDestPath, err := filepath.Abs(cacheDestPath) - if err != nil { - return fmt.Errorf("处理路径失败: %w", err) - } - if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil { - return fmt.Errorf("创建目录失败: %w", err) - } - - if task.StoragePath == "" { - task.StoragePath = task.File.FileName - } - - taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) - if err != nil { - return err - } - task.StoragePath = taskStorage.JoinStoragePath(*task) - - if task.File.FileSize == 0 { - return processPhoto(task, taskStorage, cacheDestPath) - } - - ctx := task.Ctx.(*ext.Context) - - barTotalCount := calculateBarTotalCount(task.File.FileSize) - - progressCallback := func(bytesRead, contentLength int64) { - progress := float64(bytesRead) / float64(contentLength) * 100 - logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress) - if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 { - return - } - text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress) - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - }) - } - - text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0) - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, - }) - - readCloser, err := NewTelegramReader(task.Ctx, bot.Client, &task.File.Location, - 0, task.File.FileSize-1, task.File.FileSize, - progressCallback, task.File.FileSize/100) - if err != nil { - return fmt.Errorf("创建下载失败: %w", err) - } - defer readCloser.Close() - - dest, err := os.Create(cacheDestPath) - if err != nil { - return fmt.Errorf("创建文件失败: %w", err) - } - defer dest.Close() - task.StartTime = time.Now() - if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { - return fmt.Errorf("下载文件失败: %w", err) - } - defer cleanCacheFile(cacheDestPath) - if path.Ext(task.FileName()) == "" { - mimeType, err := mimetype.DetectFile(cacheDestPath) - if err != nil { - logger.L.Errorf("Failed to detect mime type: %s", err) - } else { - task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension()) - task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension()) - } - } - - logger.L.Infof("Downloaded file: %s", cacheDestPath) - ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), - ID: task.ReplyMessageID, - }) - - return saveFileWithRetry(task, taskStorage, cacheDestPath) -} - func worker(queue *queue.TaskQueue, semaphore chan struct{}) { for { semaphore <- struct{}{} diff --git a/core/download.go b/core/download.go new file mode 100644 index 0000000..54385fa --- /dev/null +++ b/core/download.go @@ -0,0 +1,89 @@ +package core + +import ( + "fmt" + "io" + "os" + "path/filepath" + "time" + + "github.com/celestix/gotgproto/ext" + "github.com/duke-git/lancet/v2/fileutil" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/bot" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/storage" + "github.com/krau/SaveAny-Bot/types" +) + +func processPendingTask(task *types.Task) error { + logger.L.Debugf("Start processing task: %s", task.String()) + if task.FileName() == "" { + task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash()) + } + cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) + cacheDestPath, err := filepath.Abs(cacheDestPath) + if err != nil { + return fmt.Errorf("处理路径失败: %w", err) + } + if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil { + return fmt.Errorf("创建目录失败: %w", err) + } + + if task.StoragePath == "" { + task.StoragePath = task.File.FileName + } + + taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) + if err != nil { + return err + } + task.StoragePath = taskStorage.JoinStoragePath(*task) + + if task.File.FileSize == 0 { + return processPhoto(task, taskStorage, cacheDestPath) + } + + ctx, ok := task.Ctx.(*ext.Context) + if !ok { + return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) + } + + barTotalCount := calculateBarTotalCount(task.File.FileSize) + text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0) + ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + }) + progressCallback := buildProgressCallback(ctx, task, barTotalCount) + readCloser, err := NewTelegramReader(ctx, bot.Client, &task.File.Location, + 0, task.File.FileSize-1, task.File.FileSize, + progressCallback, task.File.FileSize/100) + if err != nil { + return fmt.Errorf("创建下载失败: %w", err) + } + defer readCloser.Close() + + dest, err := os.Create(cacheDestPath) + if err != nil { + return fmt.Errorf("创建文件失败: %w", err) + } + defer dest.Close() + task.StartTime = time.Now() + if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { + return fmt.Errorf("下载文件失败: %w", err) + } + defer cleanCacheFile(cacheDestPath) + + fixTaskFileExt(task, cacheDestPath) + + logger.L.Infof("Downloaded file: %s", cacheDestPath) + ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()), + ID: task.ReplyMessageID, + }) + + return saveFileWithRetry(task, taskStorage, cacheDestPath) +} diff --git a/core/reader.go b/core/reader.go index 091b4fa..4199ed7 100644 --- a/core/reader.go +++ b/core/reader.go @@ -16,7 +16,7 @@ type telegramReader struct { location *tg.InputFileLocationClass bytesread int64 chunkSize int64 - i int64 + copied int64 contentLength int64 start int64 end int64 @@ -32,12 +32,12 @@ func (*telegramReader) Close() error { return nil } -func (r *telegramReader) Read(p []byte) (n int, err error) { +func (r *telegramReader) Read(dst []byte) (n int, err error) { if r.bytesread == r.contentLength { return 0, io.EOF } - if r.i >= int64(len(r.buffer)) { + if r.copied >= int64(len(r.buffer)) { r.buffer, err = r.next() if err != nil { return 0, err @@ -50,10 +50,10 @@ func (r *telegramReader) Read(p []byte) (n int, err error) { } } - r.i = 0 + r.copied = 0 } - n = copy(p, r.buffer[r.i:]) - r.i += int64(n) + n = copy(dst, r.buffer[r.copied:]) + r.copied += int64(n) r.bytesread += int64(n) if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) { @@ -64,33 +64,6 @@ func (r *telegramReader) Read(p []byte) (n int, err error) { return n, nil } -func NewTelegramReader( - ctx context.Context, - client *gotgproto.Client, - location *tg.InputFileLocationClass, - start int64, - end int64, - contentLength int64, - progressCallback func(bytesRead, contentLength int64), - callbackInterval int64, -) (io.ReadCloser, error) { - - r := &telegramReader{ - ctx: ctx, - location: location, - client: client, - start: start, - end: end, - chunkSize: int64(1024 * 1024), - contentLength: contentLength, - progressCallback: progressCallback, - callbackInterval: callbackInterval, - } - - r.next = r.partStream() - return r, nil -} - func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) { var lastError error for i := 0; i < config.Cfg.Retry; i++ { @@ -152,3 +125,30 @@ func (r *telegramReader) partStream() func() ([]byte, error) { } return readData } + +func NewTelegramReader( + ctx context.Context, + client *gotgproto.Client, + location *tg.InputFileLocationClass, + start int64, + end int64, + contentLength int64, + progressCallback func(bytesRead, contentLength int64), + callbackInterval int64, +) (io.ReadCloser, error) { + + r := &telegramReader{ + ctx: ctx, + location: location, + client: client, + start: start, + end: end, + chunkSize: int64(1024 * 1024), + contentLength: contentLength, + progressCallback: progressCallback, + callbackInterval: callbackInterval, + } + + r.next = r.partStream() + return r, nil +} diff --git a/core/utils.go b/core/utils.go index 25a9f8c..dc0d018 100644 --- a/core/utils.go +++ b/core/utils.go @@ -3,8 +3,11 @@ package core import ( "fmt" "os" + "path" "time" + "github.com/celestix/gotgproto/ext" + "github.com/gabriel-vasile/mimetype" "github.com/gotd/td/telegram/message/entity" "github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/tg" @@ -125,3 +128,31 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i } return entityBuilder.Complete() } + +func buildProgressCallback(ctx *ext.Context, task *types.Task, barTotalCount int) func(bytesRead, contentLength int64) { + return func(bytesRead, contentLength int64) { + progress := float64(bytesRead) / float64(contentLength) * 100 + logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress) + if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 { + return + } + text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress) + ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + }) + } +} + +func fixTaskFileExt(task *types.Task, localFilePath string) { + if path.Ext(task.FileName()) == "" { + mimeType, err := mimetype.DetectFile(localFilePath) + if err != nil { + logger.L.Errorf("Failed to detect mime type: %s", err) + } else { + task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension()) + task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension()) + } + } +}