feat: basic aria2 integration

This commit is contained in:
krau
2026-01-17 14:57:03 +08:00
parent f17a380579
commit b05d86509c
16 changed files with 809 additions and 15 deletions

View File

@@ -90,6 +90,15 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID) shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID)
case tasktype.TaskTypeDirectlinks: case tasktype.TaskTypeDirectlinks:
shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID) shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID)
case tasktype.TaskTypeAria2:
client := GetAria2Client()
if client == nil {
ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
"Error": "aria2 client not initialized",
})))
return dispatcher.EndGroups
}
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
default: default:
return fmt.Errorf("unexcept task type: %s", data.TaskType) return fmt.Errorf("unexcept task type: %s", data.TaskType)
} }

View File

@@ -58,6 +58,11 @@ var aria2ClientInitOnce sync.Once
var aria2ClientInitErr error var aria2ClientInitErr error
var aria2Client *aria2.Client var aria2Client *aria2.Client
// GetAria2Client returns the shared aria2 client instance
func GetAria2Client() *aria2.Client {
return aria2Client
}
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
if !config.C().Aria2.Enable { if !config.C().Aria2.Enable {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil) ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
@@ -78,7 +83,9 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil) ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
return nil return nil
} }
logger.Debug("Adding aria2 download", "links", links) logger.Debug("Preparing aria2 download", "links", links)
// Initialize aria2 client to check connection
aria2ClientInitOnce.Do(func() { aria2ClientInitOnce.Do(func() {
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret) aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
}) })
@@ -89,17 +96,18 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
})), nil) })), nil)
return nil return nil
} }
gid, err := aria2Client.AddURI(ctx, links, nil)
// Build storage selection keyboard (don't add to aria2 yet)
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
TaskType: tasktype.TaskTypeAria2,
Aria2URIs: links,
})
if err != nil { if err != nil {
logger.Error("Failed to add aria2 download", "error", err) return err
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
"Error": err.Error(),
})), nil)
return nil
} }
logger.Info("Aria2 download added", "gid", gid)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{ ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoSelectStorage)), &ext.ReplyOpts{
"GID": gid, Markup: markup,
})), nil) })
return nil return nil
} }

View File

@@ -49,6 +49,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
ParsedItem: adddata.ParsedItem, ParsedItem: adddata.ParsedItem,
DirectLinks: adddata.DirectLinks, DirectLinks: adddata.DirectLinks,
Aria2URIs: adddata.Aria2URIs,
} }
dataid := xid.New().String() dataid := xid.New().String()
err := cache.Set(dataid, data) err := cache.Set(dataid, data)

View File

@@ -0,0 +1,65 @@
package shortcut
import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
func CreateAndAddAria2TaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, uris []string, aria2Client *aria2.Client, msgID int, userID int64) error {
logger := log.FromContext(ctx)
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
// Now add to aria2 after user selected storage
logger.Infof("Adding download to aria2, uris type: %T, value: %+v", uris, uris)
// Ensure uris is valid
if len(uris) == 0 {
logger.Error("URIs list is empty")
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgDlErrorNoValidLinks, nil),
})
return dispatcher.EndGroups
}
gid, err := aria2Client.AddURI(ctx, uris, nil)
if err != nil {
logger.Errorf("Failed to add aria2 download: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
logger.Infof("Aria2 download added with GID: %s", gid)
// Create task with the GID
task := aria2dl.NewTask(xid.New().String(), injectCtx, gid, uris, aria2Client, stor, stor.JoinStoragePath(dirPath), aria2dl.NewProgress(msgID, userID))
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
})
return dispatcher.EndGroups
}

View File

@@ -9,6 +9,7 @@ const (
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled" BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download" BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added" BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
BotMsgAria2InfoSelectStorage Key = "bot.msg.aria2.info_select_storage"
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed" BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested" BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task" BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
@@ -127,15 +128,20 @@ const (
BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success" BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success"
BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled" BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled"
BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file" BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file"
BotMsgProgressAria2Done Key = "bot.msg.progress.aria2_done"
BotMsgProgressAria2Downloading Key = "bot.msg.progress.aria2_downloading"
BotMsgProgressAria2Start Key = "bot.msg.progress.aria2_start"
BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix" BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix"
BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix" BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix"
BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix" BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix"
BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix" BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix"
BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix" BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix"
BotMsgProgressCurrentSpeedPrefix Key = "bot.msg.progress.current_speed_prefix"
BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix" BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix"
BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start" BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start"
BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix" BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix"
BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix" BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix"
BotMsgProgressDownloadedPrefix Key = "bot.msg.progress.downloaded_prefix"
BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix" BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix"
BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix" BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix"
BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix" BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix"

View File

@@ -326,7 +326,19 @@ bot:
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)" direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
file_name_prefix: "Filename: " file_name_prefix: "Filename: "
error_prefix: "\nError: " error_prefix: "\nError: "
aria2_start: "Waiting for Aria2 to complete download (GID: {{.GID}})..."
aria2_downloading: "Aria2 is downloading (GID: {{.GID}})\n"
aria2_done: "Aria2 download completed and saved (GID: {{.GID}})\n"
downloaded_prefix: "\nDownloaded: "
current_speed_prefix: "\nCurrent speed: "
syncpeers: syncpeers:
start: "Starting to sync peers..." start: "Starting to sync peers..."
done: "Peer sync completed, total {{.Count}} chats synced" done: "Peer sync completed, total {{.Count}} chats synced"
failed: "Peer sync failed: {{.Error}}" failed: "Peer sync failed: {{.Error}}"
aria2:
error_aria2_not_enabled: "Aria2 feature is not enabled in the configuration"
error_aria2_client_init_failed: "Aria2 client initialization failed: {{.Error}}"
info_adding_aria2_download: "Adding Aria2 download task..."
error_adding_aria2_download: "Failed to add Aria2 download task: {{.Error}}"
info_aria2_download_added: "Aria2 download task added, GID: {{.GID}}"
info_select_storage: "Please select storage, the task will be added to Aria2 download queue after selection"

View File

@@ -328,6 +328,11 @@ bot:
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)" direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
file_name_prefix: "文件名: " file_name_prefix: "文件名: "
error_prefix: "\n错误: " error_prefix: "\n错误: "
aria2_start: "等待 Aria2 下载完成 (GID: {{.GID}})..."
aria2_downloading: "Aria2 正在下载 (GID: {{.GID}})\n"
aria2_done: "Aria2 下载完成并已转存 (GID: {{.GID}})\n"
downloaded_prefix: "\n已下载: "
current_speed_prefix: "\n当前速度: "
syncpeers: syncpeers:
start: "正在同步对话列表..." start: "正在同步对话列表..."
success: "对话列表同步完成, 共同步 {{.Count}} 个对话" success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
@@ -338,3 +343,4 @@ bot:
info_adding_aria2_download: "正在添加 Aria2 下载任务..." info_adding_aria2_download: "正在添加 Aria2 下载任务..."
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}" error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}" info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列"

View File

@@ -18,6 +18,17 @@ token = ""
enable = false enable = false
url = "socks5://127.0.0.1:7890" url = "socks5://127.0.0.1:7890"
# Aria2 配置
[aria2]
# 启用 Aria2 下载支持
enable = false
# Aria2 RPC URL
url = "http://localhost:6800/jsonrpc"
# Aria2 RPC Secret (如果配置了 rpc-secret)
secret = ""
# 转存完成后删除 Aria2 下载的本地文件
remove_after_transfer = true
# 存储列表 # 存储列表
[[storages]] [[storages]]
# 标识名, 需要唯一 # 标识名, 需要唯一

View File

@@ -36,9 +36,10 @@ type Config struct {
} }
type aria2Config struct { type aria2Config struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"` Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Url string `toml:"url" mapstructure:"url" json:"url"` Url string `toml:"url" mapstructure:"url" json:"url"`
Secret string `toml:"secret" mapstructure:"secret" json:"secret"` Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
RemoveAfterTransfer bool `toml:"remove_after_transfer" mapstructure:"remove_after_transfer" json:"remove_after_transfer"`
} }
var cfg = &Config{} var cfg = &Config{}

View File

@@ -0,0 +1,208 @@
package aria2dl
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
)
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID)
if t.Progress != nil {
t.Progress.OnStart(ctx, t)
}
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
var status *aria2.Status
var err error
for {
select {
case <-ctx.Done():
logger.Warn("Aria2 task canceled")
if t.Progress != nil {
t.Progress.OnDone(ctx, t, ctx.Err())
}
return ctx.Err()
case <-ticker.C:
// Try to get status from active/waiting queue first
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
if err != nil {
// If GID not found in active queue, check stopped queue
logger.Debugf("Task not in active queue, checking stopped queue: %v", err)
stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100)
if stopErr != nil {
logger.Errorf("Failed to get stopped tasks: %v", stopErr)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to get aria2 status: %w", err)
}
// Find our task in stopped queue
found := false
for _, task := range stoppedTasks {
if task.GID == t.GID {
status = &task
found = true
logger.Debugf("Found task in stopped queue with status: %s", status.Status)
break
}
}
if !found {
logger.Errorf("Task GID %s not found in active or stopped queue", t.GID)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("aria2 task not found: %w", err)
}
}
logger.Debugf("Aria2 GID %s status: %s, completed: %s/%s",
t.GID, status.Status, status.CompletedLength, status.TotalLength)
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, status)
}
// Check if download is complete
if status.IsDownloadComplete() {
logger.Infof("Aria2 download completed for GID %s", t.GID)
goto TransferFiles
}
// Check for errors
if status.IsDownloadError() {
err := fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode)
logger.Errorf("Aria2 download failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
// Check if removed
if status.IsDownloadRemoved() {
err := errors.New("aria2 download was removed")
logger.Error("Aria2 download was removed")
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
}
}
TransferFiles:
// Get final status to get file list
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
if err != nil {
logger.Errorf("Failed to get final status: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to get final status: %w", err)
}
if len(status.Files) == 0 {
err := errors.New("no files in aria2 download")
logger.Error("No files in aria2 download")
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
// Transfer files to storage
logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name())
for _, file := range status.Files {
if file.Selected != "true" {
logger.Debugf("Skipping unselected file: %s", file.Path)
continue
}
// Check if file exists
if _, err := os.Stat(file.Path); os.IsNotExist(err) {
logger.Errorf("Downloaded file not found: %s", file.Path)
continue
}
// Open file
f, err := os.Open(file.Path)
if err != nil {
logger.Errorf("Failed to open file %s: %v", file.Path, err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to open file %s: %w", file.Path, err)
}
defer f.Close()
// Get file info
fileInfo, err := f.Stat()
if err != nil {
logger.Errorf("Failed to stat file %s: %v", file.Path, err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to stat file %s: %w", file.Path, err)
}
// Set content length in context for storage
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
// Determine destination path
fileName := filepath.Base(file.Path)
destPath := filepath.Join(t.StorPath, fileName)
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
// Save to storage
err = t.Storage.Save(ctx, f, destPath)
if err != nil {
logger.Errorf("Failed to save file %s to storage: %v", fileName, err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
}
logger.Infof("Successfully transferred file %s", fileName)
// Optionally remove the local file after successful transfer
if config.C().Aria2.RemoveAfterTransfer {
if err := os.Remove(file.Path); err != nil {
logger.Warnf("Failed to remove local file %s: %v", file.Path, err)
} else {
logger.Debugf("Removed local file %s", file.Path)
}
}
}
logger.Infof("Aria2 task %s completed successfully", t.ID)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, nil)
}
// Clean up aria2 download result
_, err = t.Aria2Client.RemoveDownloadResult(ctx, t.GID)
if err != nil {
logger.Warnf("Failed to remove aria2 download result: %v", err)
}
return nil
}

View File

@@ -0,0 +1,189 @@
package aria2dl
import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/pkg/aria2"
)
type ProgressTracker interface {
OnStart(ctx context.Context, task *Task)
OnProgress(ctx context.Context, task *Task, status *aria2.Status)
OnDone(ctx context.Context, task *Task, err error)
}
type Progress struct {
msgID int
chatID int64
start time.Time
lastUpdatePercent atomic.Int32
}
// OnStart implements ProgressTracker.
func (p *Progress) OnStart(ctx context.Context, task *Task) {
logger := log.FromContext(ctx)
p.start = time.Now()
p.lastUpdatePercent.Store(0)
logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID)
ext := tgutil.ExtFromContext(ctx)
if ext == nil {
return
}
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{
"GID": task.GID,
}))); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext.EditMessage(p.chatID, req)
}
// OnProgress implements ProgressTracker.
func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
if totalLength == 0 {
return
}
percent := int((completedLength * 100) / totalLength)
if p.lastUpdatePercent.Load() == int32(percent) {
return
}
p.lastUpdatePercent.Store(int32(percent))
log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength)
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{
"GID": task.GID,
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)),
styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
// OnDone implements ProgressTracker.
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
logger := log.FromContext(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
logger.Infof("Aria2 task %s was canceled", task.TaskID())
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
"TaskID": task.TaskID(),
}),
})
}
} else {
logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
"Error": err.Error(),
}),
})
}
}
return
}
logger.Infof("Aria2 task %s completed successfully", task.TaskID())
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{
"GID": task.GID,
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
logger.Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
var _ ProgressTracker = (*Progress)(nil)
func NewProgress(msgID int, userID int64) ProgressTracker {
return &Progress{
msgID: msgID,
chatID: userID,
}
}

View File

@@ -0,0 +1,61 @@
package aria2dl
import (
"context"
"fmt"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
var _ core.Executable = (*Task)(nil)
type Task struct {
ID string
ctx context.Context
GID string
URIs []string
Aria2Client *aria2.Client
Storage storage.Storage
StorPath string
Progress ProgressTracker
}
// Title implements core.Executable.
func (t *Task) Title() string {
return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath)
}
// Type implements core.Executable.
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeAria2
}
// TaskID implements core.Executable.
func (t *Task) TaskID() string {
return t.ID
}
func NewTask(
id string,
ctx context.Context,
gid string,
uris []string,
aria2Client *aria2.Client,
stor storage.Storage,
storPath string,
progressTracker ProgressTracker,
) *Task {
return &Task{
ID: id,
ctx: ctx,
GID: gid,
URIs: uris,
Aria2Client: aria2Client,
Storage: stor,
StorPath: storPath,
Progress: progressTracker,
}
}

View File

@@ -0,0 +1,209 @@
package aria2dl
import (
"context"
"io"
"testing"
"time"
storconfig "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/pkg/aria2"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
type mockStorage struct {
name string
savePath string
}
func (m *mockStorage) Name() string {
return m.name
}
func (m *mockStorage) Type() storenum.StorageType {
return storenum.StorageType("mock")
}
func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error {
return nil
}
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
m.savePath = path
return nil
}
func (m *mockStorage) Exists(ctx context.Context, path string) bool {
return false
}
func (m *mockStorage) JoinStoragePath(path string) string {
return path
}
type mockProgress struct {
started bool
done bool
doneErr error
progress int
}
func (m *mockProgress) OnStart(ctx context.Context, task *Task) {
m.started = true
}
func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
m.progress++
}
func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) {
m.done = true
m.doneErr = err
}
func TestTaskCreation(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
mockProg,
)
if task.ID != "test-task-id" {
t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID)
}
if task.GID != "test-gid" {
t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID)
}
if task.Type() != tasktype.TaskTypeAria2 {
t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type())
}
if task.TaskID() != "test-task-id" {
t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID())
}
if task.Storage.Name() != "test-storage" {
t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name())
}
}
func TestProgressTracker(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
mockProg,
)
// Test OnStart
mockProg.OnStart(ctx, task)
if !mockProg.started {
t.Error("Expected OnStart to set started to true")
}
// Test OnProgress
status := &aria2.Status{
GID: "test-gid",
Status: "active",
TotalLength: "1000000",
CompletedLength: "500000",
DownloadSpeed: "100000",
}
mockProg.OnProgress(ctx, task, status)
if mockProg.progress != 1 {
t.Errorf("Expected progress to be 1, got %d", mockProg.progress)
}
// Test OnDone
mockProg.OnDone(ctx, task, nil)
if !mockProg.done {
t.Error("Expected OnDone to set done to true")
}
if mockProg.doneErr != nil {
t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr)
}
}
func TestTaskTitle(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
task := NewTask(
"test-task-id",
ctx,
"test-gid-123",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
nil,
)
title := task.Title()
expectedSubstr := "test-gid-123"
if len(title) == 0 {
t.Error("Expected title to not be empty")
}
// Check if title contains the GID
found := false
for i := 0; i < len(title)-len(expectedSubstr)+1; i++ {
if title[i:i+len(expectedSubstr)] == expectedSubstr {
found = true
break
}
}
if !found {
t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title)
}
}
func TestContextCancellation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil, // nil client will cause Execute to fail/timeout
mockStor,
"/test/path",
mockProg,
)
// Just verify the task structure is valid
if task.ctx.Err() != nil {
t.Error("Context should not be cancelled yet")
}
// Wait for context to timeout
<-ctx.Done()
if ctx.Err() == nil {
t.Error("Context should be cancelled after timeout")
}
}

View File

@@ -1,5 +1,5 @@
package tasktype package tasktype
//go:generate go-enum --values --names --flag --nocase //go:generate go-enum --values --names --flag --nocase
// ENUM(tgfiles,tphpics,parseditem,directlinks) // ENUM(tgfiles,tphpics,parseditem,directlinks,aria2)
type TaskType string type TaskType string

View File

@@ -20,6 +20,8 @@ const (
TaskTypeParseditem TaskType = "parseditem" TaskTypeParseditem TaskType = "parseditem"
// TaskTypeDirectlinks is a TaskType of type directlinks. // TaskTypeDirectlinks is a TaskType of type directlinks.
TaskTypeDirectlinks TaskType = "directlinks" TaskTypeDirectlinks TaskType = "directlinks"
// TaskTypeAria2 is a TaskType of type aria2.
TaskTypeAria2 TaskType = "aria2"
) )
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", ")) var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
@@ -29,6 +31,7 @@ var _TaskTypeNames = []string{
string(TaskTypeTphpics), string(TaskTypeTphpics),
string(TaskTypeParseditem), string(TaskTypeParseditem),
string(TaskTypeDirectlinks), string(TaskTypeDirectlinks),
string(TaskTypeAria2),
} }
// TaskTypeNames returns a list of possible string values of TaskType. // TaskTypeNames returns a list of possible string values of TaskType.
@@ -45,6 +48,7 @@ func TaskTypeValues() []TaskType {
TaskTypeTphpics, TaskTypeTphpics,
TaskTypeParseditem, TaskTypeParseditem,
TaskTypeDirectlinks, TaskTypeDirectlinks,
TaskTypeAria2,
} }
} }
@@ -65,6 +69,7 @@ var _TaskTypeValue = map[string]TaskType{
"tphpics": TaskTypeTphpics, "tphpics": TaskTypeTphpics,
"parseditem": TaskTypeParseditem, "parseditem": TaskTypeParseditem,
"directlinks": TaskTypeDirectlinks, "directlinks": TaskTypeDirectlinks,
"aria2": TaskTypeAria2,
} }
// ParseTaskType attempts to convert a string to a TaskType. // ParseTaskType attempts to convert a string to a TaskType.

View File

@@ -45,6 +45,8 @@ type Add struct {
ParsedItem *parser.Item ParsedItem *parser.Item
// directlinks // directlinks
DirectLinks []string DirectLinks []string
// aria2
Aria2URIs []string
} }
type SetDefaultStorage struct { type SetDefaultStorage struct {