From b05d86509c31c881e882bb0c4472a0133003a957 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 17 Jan 2026 14:57:03 +0800 Subject: [PATCH] feat: basic aria2 integration --- client/bot/handlers/add_task.go | 9 + client/bot/handlers/dl.go | 30 ++- client/bot/handlers/utils/msgelem/storage.go | 2 + client/bot/handlers/utils/shortcut/aria2.go | 65 ++++++ common/i18n/i18nk/keys.go | 6 + common/i18n/locale/en.yaml | 12 ++ common/i18n/locale/zh-Hans.yaml | 6 + config.example.toml | 11 + config/viper.go | 7 +- core/tasks/aria2dl/execute.go | 208 ++++++++++++++++++ core/tasks/aria2dl/progress.go | 189 +++++++++++++++++ core/tasks/aria2dl/task.go | 61 ++++++ core/tasks/aria2dl/task_test.go | 209 +++++++++++++++++++ pkg/enums/tasktype/tasktype.go | 2 +- pkg/enums/tasktype/tasktype_enum.go | 5 + pkg/tcbdata/data.go | 2 + 16 files changed, 809 insertions(+), 15 deletions(-) create mode 100644 client/bot/handlers/utils/shortcut/aria2.go create mode 100644 core/tasks/aria2dl/execute.go create mode 100644 core/tasks/aria2dl/progress.go create mode 100644 core/tasks/aria2dl/task.go create mode 100644 core/tasks/aria2dl/task_test.go diff --git a/client/bot/handlers/add_task.go b/client/bot/handlers/add_task.go index 7e8ed18..6a04f22 100644 --- a/client/bot/handlers/add_task.go +++ b/client/bot/handlers/add_task.go @@ -90,6 +90,15 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error { shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID) case tasktype.TaskTypeDirectlinks: shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID) + case tasktype.TaskTypeAria2: + client := GetAria2Client() + if client == nil { + ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{ + "Error": "aria2 client not initialized", + }))) + return dispatcher.EndGroups + } + shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID) default: return fmt.Errorf("unexcept task type: %s", data.TaskType) } diff --git a/client/bot/handlers/dl.go b/client/bot/handlers/dl.go index 0aefbb0..518886d 100644 --- a/client/bot/handlers/dl.go +++ b/client/bot/handlers/dl.go @@ -58,6 +58,11 @@ var aria2ClientInitOnce sync.Once var aria2ClientInitErr error var aria2Client *aria2.Client +// GetAria2Client returns the shared aria2 client instance +func GetAria2Client() *aria2.Client { + return aria2Client +} + func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { if !config.C().Aria2.Enable { ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil) @@ -78,7 +83,9 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil) return nil } - logger.Debug("Adding aria2 download", "links", links) + logger.Debug("Preparing aria2 download", "links", links) + + // Initialize aria2 client to check connection aria2ClientInitOnce.Do(func() { aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret) }) @@ -89,17 +96,18 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { })), nil) return nil } - gid, err := aria2Client.AddURI(ctx, links, nil) + + // Build storage selection keyboard (don't add to aria2 yet) + markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{ + TaskType: tasktype.TaskTypeAria2, + Aria2URIs: links, + }) if err != nil { - logger.Error("Failed to add aria2 download", "error", err) - ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{ - "Error": err.Error(), - })), nil) - return nil + return err } - logger.Info("Aria2 download added", "gid", gid) - ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{ - "GID": gid, - })), nil) + + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoSelectStorage)), &ext.ReplyOpts{ + Markup: markup, + }) return nil } diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go index c037fc4..efdb59c 100644 --- a/client/bot/handlers/utils/msgelem/storage.go +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -49,6 +49,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) ParsedItem: adddata.ParsedItem, DirectLinks: adddata.DirectLinks, + + Aria2URIs: adddata.Aria2URIs, } dataid := xid.New().String() err := cache.Set(dataid, data) diff --git a/client/bot/handlers/utils/shortcut/aria2.go b/client/bot/handlers/utils/shortcut/aria2.go new file mode 100644 index 0000000..a564087 --- /dev/null +++ b/client/bot/handlers/utils/shortcut/aria2.go @@ -0,0 +1,65 @@ +package shortcut + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/tasks/aria2dl" + "github.com/krau/SaveAny-Bot/pkg/aria2" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +func CreateAndAddAria2TaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, uris []string, aria2Client *aria2.Client, msgID int, userID int64) error { + logger := log.FromContext(ctx) + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + + // Now add to aria2 after user selected storage + logger.Infof("Adding download to aria2, uris type: %T, value: %+v", uris, uris) + + // Ensure uris is valid + if len(uris) == 0 { + logger.Error("URIs list is empty") + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: msgID, + Message: i18n.T(i18nk.BotMsgDlErrorNoValidLinks, nil), + }) + return dispatcher.EndGroups + } + + gid, err := aria2Client.AddURI(ctx, uris, nil) + if err != nil { + logger.Errorf("Failed to add aria2 download: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: msgID, + Message: i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{ + "Error": err.Error(), + }), + }) + return dispatcher.EndGroups + } + logger.Infof("Aria2 download added with GID: %s", gid) + + // Create task with the GID + task := aria2dl.NewTask(xid.New().String(), injectCtx, gid, uris, aria2Client, stor, stor.JoinStoragePath(dirPath), aria2dl.NewProgress(msgID, userID)) + if err := core.AddTask(injectCtx, task); err != nil { + logger.Errorf("Failed to add task: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: msgID, + Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ + "Error": err.Error(), + }), + }) + return dispatcher.EndGroups + } + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: msgID, + Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil), + }) + return dispatcher.EndGroups +} diff --git a/common/i18n/i18nk/keys.go b/common/i18n/i18nk/keys.go index afcbdbc..3129f79 100644 --- a/common/i18n/i18nk/keys.go +++ b/common/i18n/i18nk/keys.go @@ -9,6 +9,7 @@ const ( BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled" BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download" BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added" + BotMsgAria2InfoSelectStorage Key = "bot.msg.aria2.info_select_storage" BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed" BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested" BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task" @@ -127,15 +128,20 @@ const ( BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success" BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled" BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file" + BotMsgProgressAria2Done Key = "bot.msg.progress.aria2_done" + BotMsgProgressAria2Downloading Key = "bot.msg.progress.aria2_downloading" + BotMsgProgressAria2Start Key = "bot.msg.progress.aria2_start" BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix" BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix" BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix" BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix" BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix" + BotMsgProgressCurrentSpeedPrefix Key = "bot.msg.progress.current_speed_prefix" BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix" BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start" BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix" BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix" + BotMsgProgressDownloadedPrefix Key = "bot.msg.progress.downloaded_prefix" BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix" BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix" BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix" diff --git a/common/i18n/locale/en.yaml b/common/i18n/locale/en.yaml index 7d6fcf8..ffb15d5 100644 --- a/common/i18n/locale/en.yaml +++ b/common/i18n/locale/en.yaml @@ -326,7 +326,19 @@ bot: direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)" file_name_prefix: "Filename: " error_prefix: "\nError: " + aria2_start: "Waiting for Aria2 to complete download (GID: {{.GID}})..." + aria2_downloading: "Aria2 is downloading (GID: {{.GID}})\n" + aria2_done: "Aria2 download completed and saved (GID: {{.GID}})\n" + downloaded_prefix: "\nDownloaded: " + current_speed_prefix: "\nCurrent speed: " syncpeers: start: "Starting to sync peers..." done: "Peer sync completed, total {{.Count}} chats synced" failed: "Peer sync failed: {{.Error}}" + aria2: + error_aria2_not_enabled: "Aria2 feature is not enabled in the configuration" + error_aria2_client_init_failed: "Aria2 client initialization failed: {{.Error}}" + info_adding_aria2_download: "Adding Aria2 download task..." + error_adding_aria2_download: "Failed to add Aria2 download task: {{.Error}}" + info_aria2_download_added: "Aria2 download task added, GID: {{.GID}}" + info_select_storage: "Please select storage, the task will be added to Aria2 download queue after selection" diff --git a/common/i18n/locale/zh-Hans.yaml b/common/i18n/locale/zh-Hans.yaml index f1c121e..531ff78 100644 --- a/common/i18n/locale/zh-Hans.yaml +++ b/common/i18n/locale/zh-Hans.yaml @@ -328,6 +328,11 @@ bot: direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)" file_name_prefix: "文件名: " error_prefix: "\n错误: " + aria2_start: "等待 Aria2 下载完成 (GID: {{.GID}})..." + aria2_downloading: "Aria2 正在下载 (GID: {{.GID}})\n" + aria2_done: "Aria2 下载完成并已转存 (GID: {{.GID}})\n" + downloaded_prefix: "\n已下载: " + current_speed_prefix: "\n当前速度: " syncpeers: start: "正在同步对话列表..." success: "对话列表同步完成, 共同步 {{.Count}} 个对话" @@ -338,3 +343,4 @@ bot: info_adding_aria2_download: "正在添加 Aria2 下载任务..." error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}" info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}" + info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列" \ No newline at end of file diff --git a/config.example.toml b/config.example.toml index a94325e..bf83165 100644 --- a/config.example.toml +++ b/config.example.toml @@ -18,6 +18,17 @@ token = "" enable = false url = "socks5://127.0.0.1:7890" +# Aria2 配置 +[aria2] +# 启用 Aria2 下载支持 +enable = false +# Aria2 RPC URL +url = "http://localhost:6800/jsonrpc" +# Aria2 RPC Secret (如果配置了 rpc-secret) +secret = "" +# 转存完成后删除 Aria2 下载的本地文件 +remove_after_transfer = true + # 存储列表 [[storages]] # 标识名, 需要唯一 diff --git a/config/viper.go b/config/viper.go index 5598322..09a32e9 100644 --- a/config/viper.go +++ b/config/viper.go @@ -36,9 +36,10 @@ type Config struct { } type aria2Config struct { - Enable bool `toml:"enable" mapstructure:"enable" json:"enable"` - Url string `toml:"url" mapstructure:"url" json:"url"` - Secret string `toml:"secret" mapstructure:"secret" json:"secret"` + Enable bool `toml:"enable" mapstructure:"enable" json:"enable"` + Url string `toml:"url" mapstructure:"url" json:"url"` + Secret string `toml:"secret" mapstructure:"secret" json:"secret"` + RemoveAfterTransfer bool `toml:"remove_after_transfer" mapstructure:"remove_after_transfer" json:"remove_after_transfer"` } var cfg = &Config{} diff --git a/core/tasks/aria2dl/execute.go b/core/tasks/aria2dl/execute.go new file mode 100644 index 0000000..88de2cc --- /dev/null +++ b/core/tasks/aria2dl/execute.go @@ -0,0 +1,208 @@ +package aria2dl + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/aria2" + "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" +) + +// Execute implements core.Executable. +func (t *Task) Execute(ctx context.Context) error { + logger := log.FromContext(ctx) + logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID) + + if t.Progress != nil { + t.Progress.OnStart(ctx, t) + } + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + var status *aria2.Status + var err error + + for { + select { + case <-ctx.Done(): + logger.Warn("Aria2 task canceled") + if t.Progress != nil { + t.Progress.OnDone(ctx, t, ctx.Err()) + } + return ctx.Err() + case <-ticker.C: + // Try to get status from active/waiting queue first + status, err = t.Aria2Client.TellStatus(ctx, t.GID) + if err != nil { + // If GID not found in active queue, check stopped queue + logger.Debugf("Task not in active queue, checking stopped queue: %v", err) + stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100) + if stopErr != nil { + logger.Errorf("Failed to get stopped tasks: %v", stopErr) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("failed to get aria2 status: %w", err) + } + + // Find our task in stopped queue + found := false + for _, task := range stoppedTasks { + if task.GID == t.GID { + status = &task + found = true + logger.Debugf("Found task in stopped queue with status: %s", status.Status) + break + } + } + + if !found { + logger.Errorf("Task GID %s not found in active or stopped queue", t.GID) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("aria2 task not found: %w", err) + } + } + + logger.Debugf("Aria2 GID %s status: %s, completed: %s/%s", + t.GID, status.Status, status.CompletedLength, status.TotalLength) + + if t.Progress != nil { + t.Progress.OnProgress(ctx, t, status) + } + + // Check if download is complete + if status.IsDownloadComplete() { + logger.Infof("Aria2 download completed for GID %s", t.GID) + goto TransferFiles + } + + // Check for errors + if status.IsDownloadError() { + err := fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode) + logger.Errorf("Aria2 download failed: %v", err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return err + } + + // Check if removed + if status.IsDownloadRemoved() { + err := errors.New("aria2 download was removed") + logger.Error("Aria2 download was removed") + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return err + } + } + } + +TransferFiles: + // Get final status to get file list + status, err = t.Aria2Client.TellStatus(ctx, t.GID) + if err != nil { + logger.Errorf("Failed to get final status: %v", err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("failed to get final status: %w", err) + } + + if len(status.Files) == 0 { + err := errors.New("no files in aria2 download") + logger.Error("No files in aria2 download") + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return err + } + + // Transfer files to storage + logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name()) + for _, file := range status.Files { + if file.Selected != "true" { + logger.Debugf("Skipping unselected file: %s", file.Path) + continue + } + + // Check if file exists + if _, err := os.Stat(file.Path); os.IsNotExist(err) { + logger.Errorf("Downloaded file not found: %s", file.Path) + continue + } + + // Open file + f, err := os.Open(file.Path) + if err != nil { + logger.Errorf("Failed to open file %s: %v", file.Path, err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("failed to open file %s: %w", file.Path, err) + } + defer f.Close() + + // Get file info + fileInfo, err := f.Stat() + if err != nil { + logger.Errorf("Failed to stat file %s: %v", file.Path, err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("failed to stat file %s: %w", file.Path, err) + } + + // Set content length in context for storage + ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size()) + + // Determine destination path + fileName := filepath.Base(file.Path) + destPath := filepath.Join(t.StorPath, fileName) + + logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath) + + // Save to storage + err = t.Storage.Save(ctx, f, destPath) + if err != nil { + logger.Errorf("Failed to save file %s to storage: %v", fileName, err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return fmt.Errorf("failed to save file %s to storage: %w", fileName, err) + } + + logger.Infof("Successfully transferred file %s", fileName) + + // Optionally remove the local file after successful transfer + if config.C().Aria2.RemoveAfterTransfer { + if err := os.Remove(file.Path); err != nil { + logger.Warnf("Failed to remove local file %s: %v", file.Path, err) + } else { + logger.Debugf("Removed local file %s", file.Path) + } + } + } + + logger.Infof("Aria2 task %s completed successfully", t.ID) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, nil) + } + + // Clean up aria2 download result + _, err = t.Aria2Client.RemoveDownloadResult(ctx, t.GID) + if err != nil { + logger.Warnf("Failed to remove aria2 download result: %v", err) + } + + return nil +} diff --git a/core/tasks/aria2dl/progress.go b/core/tasks/aria2dl/progress.go new file mode 100644 index 0000000..89079d7 --- /dev/null +++ b/core/tasks/aria2dl/progress.go @@ -0,0 +1,189 @@ +package aria2dl + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/pkg/aria2" +) + +type ProgressTracker interface { + OnStart(ctx context.Context, task *Task) + OnProgress(ctx context.Context, task *Task, status *aria2.Status) + OnDone(ctx context.Context, task *Task, err error) +} + +type Progress struct { + msgID int + chatID int64 + start time.Time + lastUpdatePercent atomic.Int32 +} + +// OnStart implements ProgressTracker. +func (p *Progress) OnStart(ctx context.Context, task *Task) { + logger := log.FromContext(ctx) + p.start = time.Now() + p.lastUpdatePercent.Store(0) + logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID) + ext := tgutil.ExtFromContext(ctx) + if ext == nil { + return + } + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{ + "GID": task.GID, + }))); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(task.TaskID()), + }, + }, + }}, + ) + ext.EditMessage(p.chatID, req) +} + +// OnProgress implements ProgressTracker. +func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) { + totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64) + completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64) + downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64) + + if totalLength == 0 { + return + } + + percent := int((completedLength * 100) / totalLength) + if p.lastUpdatePercent.Load() == int32(percent) { + return + } + p.lastUpdatePercent.Store(int32(percent)) + + log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength) + + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{ + "GID": task.GID, + })), + styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)), + styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))), + styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)), + styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))), + styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)), + styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))), + styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)), + styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(task.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, req) + } +} + +// OnDone implements ProgressTracker. +func (p *Progress) OnDone(ctx context.Context, task *Task, err error) { + logger := log.FromContext(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logger.Infof("Aria2 task %s was canceled", task.TaskID()) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{ + ID: p.msgID, + Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{ + "TaskID": task.TaskID(), + }), + }) + } + } else { + logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{ + ID: p.msgID, + Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{ + "Error": err.Error(), + }), + }) + } + } + return + } + logger.Infof("Aria2 task %s completed successfully", task.TaskID()) + + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{ + "GID": task.GID, + })), + styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)), + styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)), + ); err != nil { + logger.Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, req) + } +} + +var _ ProgressTracker = (*Progress)(nil) + +func NewProgress(msgID int, userID int64) ProgressTracker { + return &Progress{ + msgID: msgID, + chatID: userID, + } +} diff --git a/core/tasks/aria2dl/task.go b/core/tasks/aria2dl/task.go new file mode 100644 index 0000000..0e09be1 --- /dev/null +++ b/core/tasks/aria2dl/task.go @@ -0,0 +1,61 @@ +package aria2dl + +import ( + "context" + "fmt" + + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/pkg/aria2" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/storage" +) + +var _ core.Executable = (*Task)(nil) + +type Task struct { + ID string + ctx context.Context + GID string + URIs []string + Aria2Client *aria2.Client + Storage storage.Storage + StorPath string + Progress ProgressTracker +} + +// Title implements core.Executable. +func (t *Task) Title() string { + return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath) +} + +// Type implements core.Executable. +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeAria2 +} + +// TaskID implements core.Executable. +func (t *Task) TaskID() string { + return t.ID +} + +func NewTask( + id string, + ctx context.Context, + gid string, + uris []string, + aria2Client *aria2.Client, + stor storage.Storage, + storPath string, + progressTracker ProgressTracker, +) *Task { + return &Task{ + ID: id, + ctx: ctx, + GID: gid, + URIs: uris, + Aria2Client: aria2Client, + Storage: stor, + StorPath: storPath, + Progress: progressTracker, + } +} diff --git a/core/tasks/aria2dl/task_test.go b/core/tasks/aria2dl/task_test.go new file mode 100644 index 0000000..4e6ba65 --- /dev/null +++ b/core/tasks/aria2dl/task_test.go @@ -0,0 +1,209 @@ +package aria2dl + +import ( + "context" + "io" + "testing" + "time" + + storconfig "github.com/krau/SaveAny-Bot/config/storage" + "github.com/krau/SaveAny-Bot/pkg/aria2" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" +) + +type mockStorage struct { + name string + savePath string +} + +func (m *mockStorage) Name() string { + return m.name +} + +func (m *mockStorage) Type() storenum.StorageType { + return storenum.StorageType("mock") +} + +func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error { + return nil +} + +func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error { + m.savePath = path + return nil +} + +func (m *mockStorage) Exists(ctx context.Context, path string) bool { + return false +} + +func (m *mockStorage) JoinStoragePath(path string) string { + return path +} + +type mockProgress struct { + started bool + done bool + doneErr error + progress int +} + +func (m *mockProgress) OnStart(ctx context.Context, task *Task) { + m.started = true +} + +func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) { + m.progress++ +} + +func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) { + m.done = true + m.doneErr = err +} + +func TestTaskCreation(t *testing.T) { + ctx := context.Background() + mockStor := &mockStorage{name: "test-storage"} + mockProg := &mockProgress{} + + task := NewTask( + "test-task-id", + ctx, + "test-gid", + []string{"http://example.com/file.zip"}, + nil, + mockStor, + "/test/path", + mockProg, + ) + + if task.ID != "test-task-id" { + t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID) + } + + if task.GID != "test-gid" { + t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID) + } + + if task.Type() != tasktype.TaskTypeAria2 { + t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type()) + } + + if task.TaskID() != "test-task-id" { + t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID()) + } + + if task.Storage.Name() != "test-storage" { + t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name()) + } +} + +func TestProgressTracker(t *testing.T) { + ctx := context.Background() + mockStor := &mockStorage{name: "test-storage"} + mockProg := &mockProgress{} + + task := NewTask( + "test-task-id", + ctx, + "test-gid", + []string{"http://example.com/file.zip"}, + nil, + mockStor, + "/test/path", + mockProg, + ) + + // Test OnStart + mockProg.OnStart(ctx, task) + if !mockProg.started { + t.Error("Expected OnStart to set started to true") + } + + // Test OnProgress + status := &aria2.Status{ + GID: "test-gid", + Status: "active", + TotalLength: "1000000", + CompletedLength: "500000", + DownloadSpeed: "100000", + } + mockProg.OnProgress(ctx, task, status) + if mockProg.progress != 1 { + t.Errorf("Expected progress to be 1, got %d", mockProg.progress) + } + + // Test OnDone + mockProg.OnDone(ctx, task, nil) + if !mockProg.done { + t.Error("Expected OnDone to set done to true") + } + if mockProg.doneErr != nil { + t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr) + } +} + +func TestTaskTitle(t *testing.T) { + ctx := context.Background() + mockStor := &mockStorage{name: "test-storage"} + + task := NewTask( + "test-task-id", + ctx, + "test-gid-123", + []string{"http://example.com/file.zip"}, + nil, + mockStor, + "/test/path", + nil, + ) + + title := task.Title() + expectedSubstr := "test-gid-123" + if len(title) == 0 { + t.Error("Expected title to not be empty") + } + + // Check if title contains the GID + found := false + for i := 0; i < len(title)-len(expectedSubstr)+1; i++ { + if title[i:i+len(expectedSubstr)] == expectedSubstr { + found = true + break + } + } + if !found { + t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title) + } +} + +func TestContextCancellation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + mockStor := &mockStorage{name: "test-storage"} + mockProg := &mockProgress{} + + task := NewTask( + "test-task-id", + ctx, + "test-gid", + []string{"http://example.com/file.zip"}, + nil, // nil client will cause Execute to fail/timeout + mockStor, + "/test/path", + mockProg, + ) + + // Just verify the task structure is valid + if task.ctx.Err() != nil { + t.Error("Context should not be cancelled yet") + } + + // Wait for context to timeout + <-ctx.Done() + if ctx.Err() == nil { + t.Error("Context should be cancelled after timeout") + } +} diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go index 454a08a..b41b2ed 100644 --- a/pkg/enums/tasktype/tasktype.go +++ b/pkg/enums/tasktype/tasktype.go @@ -1,5 +1,5 @@ package tasktype //go:generate go-enum --values --names --flag --nocase -// ENUM(tgfiles,tphpics,parseditem,directlinks) +// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2) type TaskType string diff --git a/pkg/enums/tasktype/tasktype_enum.go b/pkg/enums/tasktype/tasktype_enum.go index df0f5ae..940e269 100644 --- a/pkg/enums/tasktype/tasktype_enum.go +++ b/pkg/enums/tasktype/tasktype_enum.go @@ -20,6 +20,8 @@ const ( TaskTypeParseditem TaskType = "parseditem" // TaskTypeDirectlinks is a TaskType of type directlinks. TaskTypeDirectlinks TaskType = "directlinks" + // TaskTypeAria2 is a TaskType of type aria2. + TaskTypeAria2 TaskType = "aria2" ) var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", ")) @@ -29,6 +31,7 @@ var _TaskTypeNames = []string{ string(TaskTypeTphpics), string(TaskTypeParseditem), string(TaskTypeDirectlinks), + string(TaskTypeAria2), } // TaskTypeNames returns a list of possible string values of TaskType. @@ -45,6 +48,7 @@ func TaskTypeValues() []TaskType { TaskTypeTphpics, TaskTypeParseditem, TaskTypeDirectlinks, + TaskTypeAria2, } } @@ -65,6 +69,7 @@ var _TaskTypeValue = map[string]TaskType{ "tphpics": TaskTypeTphpics, "parseditem": TaskTypeParseditem, "directlinks": TaskTypeDirectlinks, + "aria2": TaskTypeAria2, } // ParseTaskType attempts to convert a string to a TaskType. diff --git a/pkg/tcbdata/data.go b/pkg/tcbdata/data.go index 84691b8..fcccf16 100644 --- a/pkg/tcbdata/data.go +++ b/pkg/tcbdata/data.go @@ -45,6 +45,8 @@ type Add struct { ParsedItem *parser.Item // directlinks DirectLinks []string + // aria2 + Aria2URIs []string } type SetDefaultStorage struct {