feat: basic aria2 integration
This commit is contained in:
@@ -90,6 +90,15 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
|
|||||||
shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID)
|
shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID)
|
||||||
case tasktype.TaskTypeDirectlinks:
|
case tasktype.TaskTypeDirectlinks:
|
||||||
shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID)
|
shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID)
|
||||||
|
case tasktype.TaskTypeAria2:
|
||||||
|
client := GetAria2Client()
|
||||||
|
if client == nil {
|
||||||
|
ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
|
||||||
|
"Error": "aria2 client not initialized",
|
||||||
|
})))
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unexcept task type: %s", data.TaskType)
|
return fmt.Errorf("unexcept task type: %s", data.TaskType)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,11 @@ var aria2ClientInitOnce sync.Once
|
|||||||
var aria2ClientInitErr error
|
var aria2ClientInitErr error
|
||||||
var aria2Client *aria2.Client
|
var aria2Client *aria2.Client
|
||||||
|
|
||||||
|
// GetAria2Client returns the shared aria2 client instance
|
||||||
|
func GetAria2Client() *aria2.Client {
|
||||||
|
return aria2Client
|
||||||
|
}
|
||||||
|
|
||||||
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||||
if !config.C().Aria2.Enable {
|
if !config.C().Aria2.Enable {
|
||||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
|
||||||
@@ -78,7 +83,9 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
logger.Debug("Adding aria2 download", "links", links)
|
logger.Debug("Preparing aria2 download", "links", links)
|
||||||
|
|
||||||
|
// Initialize aria2 client to check connection
|
||||||
aria2ClientInitOnce.Do(func() {
|
aria2ClientInitOnce.Do(func() {
|
||||||
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
|
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
|
||||||
})
|
})
|
||||||
@@ -89,17 +96,18 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
})), nil)
|
})), nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
gid, err := aria2Client.AddURI(ctx, links, nil)
|
|
||||||
|
// Build storage selection keyboard (don't add to aria2 yet)
|
||||||
|
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
|
||||||
|
TaskType: tasktype.TaskTypeAria2,
|
||||||
|
Aria2URIs: links,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to add aria2 download", "error", err)
|
return err
|
||||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
|
|
||||||
"Error": err.Error(),
|
|
||||||
})), nil)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
logger.Info("Aria2 download added", "gid", gid)
|
|
||||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoSelectStorage)), &ext.ReplyOpts{
|
||||||
"GID": gid,
|
Markup: markup,
|
||||||
})), nil)
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
|
|||||||
ParsedItem: adddata.ParsedItem,
|
ParsedItem: adddata.ParsedItem,
|
||||||
|
|
||||||
DirectLinks: adddata.DirectLinks,
|
DirectLinks: adddata.DirectLinks,
|
||||||
|
|
||||||
|
Aria2URIs: adddata.Aria2URIs,
|
||||||
}
|
}
|
||||||
dataid := xid.New().String()
|
dataid := xid.New().String()
|
||||||
err := cache.Set(dataid, data)
|
err := cache.Set(dataid, data)
|
||||||
|
|||||||
65
client/bot/handlers/utils/shortcut/aria2.go
Normal file
65
client/bot/handlers/utils/shortcut/aria2.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package shortcut
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/celestix/gotgproto/dispatcher"
|
||||||
|
"github.com/celestix/gotgproto/ext"
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||||
|
"github.com/krau/SaveAny-Bot/core"
|
||||||
|
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateAndAddAria2TaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, uris []string, aria2Client *aria2.Client, msgID int, userID int64) error {
|
||||||
|
logger := log.FromContext(ctx)
|
||||||
|
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
|
||||||
|
|
||||||
|
// Now add to aria2 after user selected storage
|
||||||
|
logger.Infof("Adding download to aria2, uris type: %T, value: %+v", uris, uris)
|
||||||
|
|
||||||
|
// Ensure uris is valid
|
||||||
|
if len(uris) == 0 {
|
||||||
|
logger.Error("URIs list is empty")
|
||||||
|
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgDlErrorNoValidLinks, nil),
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
gid, err := aria2Client.AddURI(ctx, uris, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to add aria2 download: %s", err)
|
||||||
|
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
|
||||||
|
"Error": err.Error(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
logger.Infof("Aria2 download added with GID: %s", gid)
|
||||||
|
|
||||||
|
// Create task with the GID
|
||||||
|
task := aria2dl.NewTask(xid.New().String(), injectCtx, gid, uris, aria2Client, stor, stor.JoinStoragePath(dirPath), aria2dl.NewProgress(msgID, userID))
|
||||||
|
if err := core.AddTask(injectCtx, task); err != nil {
|
||||||
|
logger.Errorf("Failed to add task: %s", err)
|
||||||
|
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
|
||||||
|
"Error": err.Error(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
|
||||||
|
})
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ const (
|
|||||||
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
|
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
|
||||||
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
|
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
|
||||||
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
|
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
|
||||||
|
BotMsgAria2InfoSelectStorage Key = "bot.msg.aria2.info_select_storage"
|
||||||
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
|
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
|
||||||
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
|
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
|
||||||
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
|
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
|
||||||
@@ -127,15 +128,20 @@ const (
|
|||||||
BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success"
|
BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success"
|
||||||
BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled"
|
BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled"
|
||||||
BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file"
|
BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file"
|
||||||
|
BotMsgProgressAria2Done Key = "bot.msg.progress.aria2_done"
|
||||||
|
BotMsgProgressAria2Downloading Key = "bot.msg.progress.aria2_downloading"
|
||||||
|
BotMsgProgressAria2Start Key = "bot.msg.progress.aria2_start"
|
||||||
BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix"
|
BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix"
|
||||||
BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix"
|
BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix"
|
||||||
BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix"
|
BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix"
|
||||||
BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix"
|
BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix"
|
||||||
BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix"
|
BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix"
|
||||||
|
BotMsgProgressCurrentSpeedPrefix Key = "bot.msg.progress.current_speed_prefix"
|
||||||
BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix"
|
BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix"
|
||||||
BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start"
|
BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start"
|
||||||
BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix"
|
BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix"
|
||||||
BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix"
|
BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix"
|
||||||
|
BotMsgProgressDownloadedPrefix Key = "bot.msg.progress.downloaded_prefix"
|
||||||
BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix"
|
BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix"
|
||||||
BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix"
|
BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix"
|
||||||
BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix"
|
BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix"
|
||||||
|
|||||||
@@ -326,7 +326,19 @@ bot:
|
|||||||
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
|
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
|
||||||
file_name_prefix: "Filename: "
|
file_name_prefix: "Filename: "
|
||||||
error_prefix: "\nError: "
|
error_prefix: "\nError: "
|
||||||
|
aria2_start: "Waiting for Aria2 to complete download (GID: {{.GID}})..."
|
||||||
|
aria2_downloading: "Aria2 is downloading (GID: {{.GID}})\n"
|
||||||
|
aria2_done: "Aria2 download completed and saved (GID: {{.GID}})\n"
|
||||||
|
downloaded_prefix: "\nDownloaded: "
|
||||||
|
current_speed_prefix: "\nCurrent speed: "
|
||||||
syncpeers:
|
syncpeers:
|
||||||
start: "Starting to sync peers..."
|
start: "Starting to sync peers..."
|
||||||
done: "Peer sync completed, total {{.Count}} chats synced"
|
done: "Peer sync completed, total {{.Count}} chats synced"
|
||||||
failed: "Peer sync failed: {{.Error}}"
|
failed: "Peer sync failed: {{.Error}}"
|
||||||
|
aria2:
|
||||||
|
error_aria2_not_enabled: "Aria2 feature is not enabled in the configuration"
|
||||||
|
error_aria2_client_init_failed: "Aria2 client initialization failed: {{.Error}}"
|
||||||
|
info_adding_aria2_download: "Adding Aria2 download task..."
|
||||||
|
error_adding_aria2_download: "Failed to add Aria2 download task: {{.Error}}"
|
||||||
|
info_aria2_download_added: "Aria2 download task added, GID: {{.GID}}"
|
||||||
|
info_select_storage: "Please select storage, the task will be added to Aria2 download queue after selection"
|
||||||
|
|||||||
@@ -328,6 +328,11 @@ bot:
|
|||||||
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
|
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
|
||||||
file_name_prefix: "文件名: "
|
file_name_prefix: "文件名: "
|
||||||
error_prefix: "\n错误: "
|
error_prefix: "\n错误: "
|
||||||
|
aria2_start: "等待 Aria2 下载完成 (GID: {{.GID}})..."
|
||||||
|
aria2_downloading: "Aria2 正在下载 (GID: {{.GID}})\n"
|
||||||
|
aria2_done: "Aria2 下载完成并已转存 (GID: {{.GID}})\n"
|
||||||
|
downloaded_prefix: "\n已下载: "
|
||||||
|
current_speed_prefix: "\n当前速度: "
|
||||||
syncpeers:
|
syncpeers:
|
||||||
start: "正在同步对话列表..."
|
start: "正在同步对话列表..."
|
||||||
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
|
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
|
||||||
@@ -338,3 +343,4 @@ bot:
|
|||||||
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
|
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
|
||||||
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
|
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
|
||||||
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
|
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
|
||||||
|
info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列"
|
||||||
@@ -18,6 +18,17 @@ token = ""
|
|||||||
enable = false
|
enable = false
|
||||||
url = "socks5://127.0.0.1:7890"
|
url = "socks5://127.0.0.1:7890"
|
||||||
|
|
||||||
|
# Aria2 配置
|
||||||
|
[aria2]
|
||||||
|
# 启用 Aria2 下载支持
|
||||||
|
enable = false
|
||||||
|
# Aria2 RPC URL
|
||||||
|
url = "http://localhost:6800/jsonrpc"
|
||||||
|
# Aria2 RPC Secret (如果配置了 rpc-secret)
|
||||||
|
secret = ""
|
||||||
|
# 转存完成后删除 Aria2 下载的本地文件
|
||||||
|
remove_after_transfer = true
|
||||||
|
|
||||||
# 存储列表
|
# 存储列表
|
||||||
[[storages]]
|
[[storages]]
|
||||||
# 标识名, 需要唯一
|
# 标识名, 需要唯一
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type aria2Config struct {
|
type aria2Config struct {
|
||||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||||
Url string `toml:"url" mapstructure:"url" json:"url"`
|
Url string `toml:"url" mapstructure:"url" json:"url"`
|
||||||
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
|
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
|
||||||
|
RemoveAfterTransfer bool `toml:"remove_after_transfer" mapstructure:"remove_after_transfer" json:"remove_after_transfer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg = &Config{}
|
var cfg = &Config{}
|
||||||
|
|||||||
208
core/tasks/aria2dl/execute.go
Normal file
208
core/tasks/aria2dl/execute.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package aria2dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Execute implements core.Executable.
|
||||||
|
func (t *Task) Execute(ctx context.Context) error {
|
||||||
|
logger := log.FromContext(ctx)
|
||||||
|
logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID)
|
||||||
|
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnStart(ctx, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var status *aria2.Status
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Warn("Aria2 task canceled")
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, ctx.Err())
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
// Try to get status from active/waiting queue first
|
||||||
|
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
|
||||||
|
if err != nil {
|
||||||
|
// If GID not found in active queue, check stopped queue
|
||||||
|
logger.Debugf("Task not in active queue, checking stopped queue: %v", err)
|
||||||
|
stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100)
|
||||||
|
if stopErr != nil {
|
||||||
|
logger.Errorf("Failed to get stopped tasks: %v", stopErr)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get aria2 status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find our task in stopped queue
|
||||||
|
found := false
|
||||||
|
for _, task := range stoppedTasks {
|
||||||
|
if task.GID == t.GID {
|
||||||
|
status = &task
|
||||||
|
found = true
|
||||||
|
logger.Debugf("Found task in stopped queue with status: %s", status.Status)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
logger.Errorf("Task GID %s not found in active or stopped queue", t.GID)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("aria2 task not found: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("Aria2 GID %s status: %s, completed: %s/%s",
|
||||||
|
t.GID, status.Status, status.CompletedLength, status.TotalLength)
|
||||||
|
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnProgress(ctx, t, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if download is complete
|
||||||
|
if status.IsDownloadComplete() {
|
||||||
|
logger.Infof("Aria2 download completed for GID %s", t.GID)
|
||||||
|
goto TransferFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
if status.IsDownloadError() {
|
||||||
|
err := fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode)
|
||||||
|
logger.Errorf("Aria2 download failed: %v", err)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if removed
|
||||||
|
if status.IsDownloadRemoved() {
|
||||||
|
err := errors.New("aria2 download was removed")
|
||||||
|
logger.Error("Aria2 download was removed")
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TransferFiles:
|
||||||
|
// Get final status to get file list
|
||||||
|
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to get final status: %v", err)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get final status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(status.Files) == 0 {
|
||||||
|
err := errors.New("no files in aria2 download")
|
||||||
|
logger.Error("No files in aria2 download")
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transfer files to storage
|
||||||
|
logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name())
|
||||||
|
for _, file := range status.Files {
|
||||||
|
if file.Selected != "true" {
|
||||||
|
logger.Debugf("Skipping unselected file: %s", file.Path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if file exists
|
||||||
|
if _, err := os.Stat(file.Path); os.IsNotExist(err) {
|
||||||
|
logger.Errorf("Downloaded file not found: %s", file.Path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open file
|
||||||
|
f, err := os.Open(file.Path)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to open file %s: %v", file.Path, err)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to open file %s: %w", file.Path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Get file info
|
||||||
|
fileInfo, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to stat file %s: %v", file.Path, err)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to stat file %s: %w", file.Path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set content length in context for storage
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
|
||||||
|
|
||||||
|
// Determine destination path
|
||||||
|
fileName := filepath.Base(file.Path)
|
||||||
|
destPath := filepath.Join(t.StorPath, fileName)
|
||||||
|
|
||||||
|
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
|
||||||
|
|
||||||
|
// Save to storage
|
||||||
|
err = t.Storage.Save(ctx, f, destPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to save file %s to storage: %v", fileName, err)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("Successfully transferred file %s", fileName)
|
||||||
|
|
||||||
|
// Optionally remove the local file after successful transfer
|
||||||
|
if config.C().Aria2.RemoveAfterTransfer {
|
||||||
|
if err := os.Remove(file.Path); err != nil {
|
||||||
|
logger.Warnf("Failed to remove local file %s: %v", file.Path, err)
|
||||||
|
} else {
|
||||||
|
logger.Debugf("Removed local file %s", file.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("Aria2 task %s completed successfully", t.ID)
|
||||||
|
if t.Progress != nil {
|
||||||
|
t.Progress.OnDone(ctx, t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up aria2 download result
|
||||||
|
_, err = t.Aria2Client.RemoveDownloadResult(ctx, t.GID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("Failed to remove aria2 download result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
189
core/tasks/aria2dl/progress.go
Normal file
189
core/tasks/aria2dl/progress.go
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
package aria2dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/gotd/td/telegram/message/entity"
|
||||||
|
"github.com/gotd/td/telegram/message/styling"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||||
|
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProgressTracker interface {
|
||||||
|
OnStart(ctx context.Context, task *Task)
|
||||||
|
OnProgress(ctx context.Context, task *Task, status *aria2.Status)
|
||||||
|
OnDone(ctx context.Context, task *Task, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Progress struct {
|
||||||
|
msgID int
|
||||||
|
chatID int64
|
||||||
|
start time.Time
|
||||||
|
lastUpdatePercent atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnStart implements ProgressTracker.
|
||||||
|
func (p *Progress) OnStart(ctx context.Context, task *Task) {
|
||||||
|
logger := log.FromContext(ctx)
|
||||||
|
p.start = time.Now()
|
||||||
|
p.lastUpdatePercent.Store(0)
|
||||||
|
logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID)
|
||||||
|
ext := tgutil.ExtFromContext(ctx)
|
||||||
|
if ext == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entityBuilder := entity.Builder{}
|
||||||
|
if err := styling.Perform(&entityBuilder,
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{
|
||||||
|
"GID": task.GID,
|
||||||
|
}))); err != nil {
|
||||||
|
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text, entities := entityBuilder.Complete()
|
||||||
|
req := &tg.MessagesEditMessageRequest{
|
||||||
|
ID: p.msgID,
|
||||||
|
}
|
||||||
|
req.SetMessage(text)
|
||||||
|
req.SetEntities(entities)
|
||||||
|
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||||
|
Rows: []tg.KeyboardButtonRow{
|
||||||
|
{
|
||||||
|
Buttons: []tg.KeyboardButtonClass{
|
||||||
|
tgutil.BuildCancelButton(task.TaskID()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
ext.EditMessage(p.chatID, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnProgress implements ProgressTracker.
|
||||||
|
func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||||
|
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
|
||||||
|
completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
|
||||||
|
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
|
||||||
|
|
||||||
|
if totalLength == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
percent := int((completedLength * 100) / totalLength)
|
||||||
|
if p.lastUpdatePercent.Load() == int32(percent) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.lastUpdatePercent.Store(int32(percent))
|
||||||
|
|
||||||
|
log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength)
|
||||||
|
|
||||||
|
entityBuilder := entity.Builder{}
|
||||||
|
if err := styling.Perform(&entityBuilder,
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{
|
||||||
|
"GID": task.GID,
|
||||||
|
})),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)),
|
||||||
|
styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)),
|
||||||
|
styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)),
|
||||||
|
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)),
|
||||||
|
styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))),
|
||||||
|
); err != nil {
|
||||||
|
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text, entities := entityBuilder.Complete()
|
||||||
|
req := &tg.MessagesEditMessageRequest{
|
||||||
|
ID: p.msgID,
|
||||||
|
}
|
||||||
|
req.SetMessage(text)
|
||||||
|
req.SetEntities(entities)
|
||||||
|
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||||
|
Rows: []tg.KeyboardButtonRow{
|
||||||
|
{
|
||||||
|
Buttons: []tg.KeyboardButtonClass{
|
||||||
|
tgutil.BuildCancelButton(task.TaskID()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
ext := tgutil.ExtFromContext(ctx)
|
||||||
|
if ext != nil {
|
||||||
|
ext.EditMessage(p.chatID, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDone implements ProgressTracker.
|
||||||
|
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
|
||||||
|
logger := log.FromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
logger.Infof("Aria2 task %s was canceled", task.TaskID())
|
||||||
|
ext := tgutil.ExtFromContext(ctx)
|
||||||
|
if ext != nil {
|
||||||
|
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: p.msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
|
||||||
|
"TaskID": task.TaskID(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err)
|
||||||
|
ext := tgutil.ExtFromContext(ctx)
|
||||||
|
if ext != nil {
|
||||||
|
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||||
|
ID: p.msgID,
|
||||||
|
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
|
||||||
|
"Error": err.Error(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Infof("Aria2 task %s completed successfully", task.TaskID())
|
||||||
|
|
||||||
|
entityBuilder := entity.Builder{}
|
||||||
|
if err := styling.Perform(&entityBuilder,
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{
|
||||||
|
"GID": task.GID,
|
||||||
|
})),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||||
|
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||||
|
); err != nil {
|
||||||
|
logger.Errorf("Failed to build entities: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text, entities := entityBuilder.Complete()
|
||||||
|
req := &tg.MessagesEditMessageRequest{
|
||||||
|
ID: p.msgID,
|
||||||
|
}
|
||||||
|
req.SetMessage(text)
|
||||||
|
req.SetEntities(entities)
|
||||||
|
|
||||||
|
ext := tgutil.ExtFromContext(ctx)
|
||||||
|
if ext != nil {
|
||||||
|
ext.EditMessage(p.chatID, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ProgressTracker = (*Progress)(nil)
|
||||||
|
|
||||||
|
func NewProgress(msgID int, userID int64) ProgressTracker {
|
||||||
|
return &Progress{
|
||||||
|
msgID: msgID,
|
||||||
|
chatID: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
61
core/tasks/aria2dl/task.go
Normal file
61
core/tasks/aria2dl/task.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package aria2dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/core"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ core.Executable = (*Task)(nil)
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
ID string
|
||||||
|
ctx context.Context
|
||||||
|
GID string
|
||||||
|
URIs []string
|
||||||
|
Aria2Client *aria2.Client
|
||||||
|
Storage storage.Storage
|
||||||
|
StorPath string
|
||||||
|
Progress ProgressTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// Title implements core.Executable.
|
||||||
|
func (t *Task) Title() string {
|
||||||
|
return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type implements core.Executable.
|
||||||
|
func (t *Task) Type() tasktype.TaskType {
|
||||||
|
return tasktype.TaskTypeAria2
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskID implements core.Executable.
|
||||||
|
func (t *Task) TaskID() string {
|
||||||
|
return t.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTask(
|
||||||
|
id string,
|
||||||
|
ctx context.Context,
|
||||||
|
gid string,
|
||||||
|
uris []string,
|
||||||
|
aria2Client *aria2.Client,
|
||||||
|
stor storage.Storage,
|
||||||
|
storPath string,
|
||||||
|
progressTracker ProgressTracker,
|
||||||
|
) *Task {
|
||||||
|
return &Task{
|
||||||
|
ID: id,
|
||||||
|
ctx: ctx,
|
||||||
|
GID: gid,
|
||||||
|
URIs: uris,
|
||||||
|
Aria2Client: aria2Client,
|
||||||
|
Storage: stor,
|
||||||
|
StorPath: storPath,
|
||||||
|
Progress: progressTracker,
|
||||||
|
}
|
||||||
|
}
|
||||||
209
core/tasks/aria2dl/task_test.go
Normal file
209
core/tasks/aria2dl/task_test.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package aria2dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
storconfig "github.com/krau/SaveAny-Bot/config/storage"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
|
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStorage struct {
|
||||||
|
name string
|
||||||
|
savePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) Type() storenum.StorageType {
|
||||||
|
return storenum.StorageType("mock")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
|
||||||
|
m.savePath = path
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) Exists(ctx context.Context, path string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStorage) JoinStoragePath(path string) string {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockProgress struct {
|
||||||
|
started bool
|
||||||
|
done bool
|
||||||
|
doneErr error
|
||||||
|
progress int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProgress) OnStart(ctx context.Context, task *Task) {
|
||||||
|
m.started = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||||
|
m.progress++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) {
|
||||||
|
m.done = true
|
||||||
|
m.doneErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskCreation(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
mockStor := &mockStorage{name: "test-storage"}
|
||||||
|
mockProg := &mockProgress{}
|
||||||
|
|
||||||
|
task := NewTask(
|
||||||
|
"test-task-id",
|
||||||
|
ctx,
|
||||||
|
"test-gid",
|
||||||
|
[]string{"http://example.com/file.zip"},
|
||||||
|
nil,
|
||||||
|
mockStor,
|
||||||
|
"/test/path",
|
||||||
|
mockProg,
|
||||||
|
)
|
||||||
|
|
||||||
|
if task.ID != "test-task-id" {
|
||||||
|
t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.GID != "test-gid" {
|
||||||
|
t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.Type() != tasktype.TaskTypeAria2 {
|
||||||
|
t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.TaskID() != "test-task-id" {
|
||||||
|
t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID())
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.Storage.Name() != "test-storage" {
|
||||||
|
t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProgressTracker(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
mockStor := &mockStorage{name: "test-storage"}
|
||||||
|
mockProg := &mockProgress{}
|
||||||
|
|
||||||
|
task := NewTask(
|
||||||
|
"test-task-id",
|
||||||
|
ctx,
|
||||||
|
"test-gid",
|
||||||
|
[]string{"http://example.com/file.zip"},
|
||||||
|
nil,
|
||||||
|
mockStor,
|
||||||
|
"/test/path",
|
||||||
|
mockProg,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test OnStart
|
||||||
|
mockProg.OnStart(ctx, task)
|
||||||
|
if !mockProg.started {
|
||||||
|
t.Error("Expected OnStart to set started to true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test OnProgress
|
||||||
|
status := &aria2.Status{
|
||||||
|
GID: "test-gid",
|
||||||
|
Status: "active",
|
||||||
|
TotalLength: "1000000",
|
||||||
|
CompletedLength: "500000",
|
||||||
|
DownloadSpeed: "100000",
|
||||||
|
}
|
||||||
|
mockProg.OnProgress(ctx, task, status)
|
||||||
|
if mockProg.progress != 1 {
|
||||||
|
t.Errorf("Expected progress to be 1, got %d", mockProg.progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test OnDone
|
||||||
|
mockProg.OnDone(ctx, task, nil)
|
||||||
|
if !mockProg.done {
|
||||||
|
t.Error("Expected OnDone to set done to true")
|
||||||
|
}
|
||||||
|
if mockProg.doneErr != nil {
|
||||||
|
t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskTitle(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
mockStor := &mockStorage{name: "test-storage"}
|
||||||
|
|
||||||
|
task := NewTask(
|
||||||
|
"test-task-id",
|
||||||
|
ctx,
|
||||||
|
"test-gid-123",
|
||||||
|
[]string{"http://example.com/file.zip"},
|
||||||
|
nil,
|
||||||
|
mockStor,
|
||||||
|
"/test/path",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
title := task.Title()
|
||||||
|
expectedSubstr := "test-gid-123"
|
||||||
|
if len(title) == 0 {
|
||||||
|
t.Error("Expected title to not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if title contains the GID
|
||||||
|
found := false
|
||||||
|
for i := 0; i < len(title)-len(expectedSubstr)+1; i++ {
|
||||||
|
if title[i:i+len(expectedSubstr)] == expectedSubstr {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextCancellation(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
mockStor := &mockStorage{name: "test-storage"}
|
||||||
|
mockProg := &mockProgress{}
|
||||||
|
|
||||||
|
task := NewTask(
|
||||||
|
"test-task-id",
|
||||||
|
ctx,
|
||||||
|
"test-gid",
|
||||||
|
[]string{"http://example.com/file.zip"},
|
||||||
|
nil, // nil client will cause Execute to fail/timeout
|
||||||
|
mockStor,
|
||||||
|
"/test/path",
|
||||||
|
mockProg,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Just verify the task structure is valid
|
||||||
|
if task.ctx.Err() != nil {
|
||||||
|
t.Error("Context should not be cancelled yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for context to timeout
|
||||||
|
<-ctx.Done()
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Error("Context should be cancelled after timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
package tasktype
|
package tasktype
|
||||||
|
|
||||||
//go:generate go-enum --values --names --flag --nocase
|
//go:generate go-enum --values --names --flag --nocase
|
||||||
// ENUM(tgfiles,tphpics,parseditem,directlinks)
|
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2)
|
||||||
type TaskType string
|
type TaskType string
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ const (
|
|||||||
TaskTypeParseditem TaskType = "parseditem"
|
TaskTypeParseditem TaskType = "parseditem"
|
||||||
// TaskTypeDirectlinks is a TaskType of type directlinks.
|
// TaskTypeDirectlinks is a TaskType of type directlinks.
|
||||||
TaskTypeDirectlinks TaskType = "directlinks"
|
TaskTypeDirectlinks TaskType = "directlinks"
|
||||||
|
// TaskTypeAria2 is a TaskType of type aria2.
|
||||||
|
TaskTypeAria2 TaskType = "aria2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
|
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
|
||||||
@@ -29,6 +31,7 @@ var _TaskTypeNames = []string{
|
|||||||
string(TaskTypeTphpics),
|
string(TaskTypeTphpics),
|
||||||
string(TaskTypeParseditem),
|
string(TaskTypeParseditem),
|
||||||
string(TaskTypeDirectlinks),
|
string(TaskTypeDirectlinks),
|
||||||
|
string(TaskTypeAria2),
|
||||||
}
|
}
|
||||||
|
|
||||||
// TaskTypeNames returns a list of possible string values of TaskType.
|
// TaskTypeNames returns a list of possible string values of TaskType.
|
||||||
@@ -45,6 +48,7 @@ func TaskTypeValues() []TaskType {
|
|||||||
TaskTypeTphpics,
|
TaskTypeTphpics,
|
||||||
TaskTypeParseditem,
|
TaskTypeParseditem,
|
||||||
TaskTypeDirectlinks,
|
TaskTypeDirectlinks,
|
||||||
|
TaskTypeAria2,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +69,7 @@ var _TaskTypeValue = map[string]TaskType{
|
|||||||
"tphpics": TaskTypeTphpics,
|
"tphpics": TaskTypeTphpics,
|
||||||
"parseditem": TaskTypeParseditem,
|
"parseditem": TaskTypeParseditem,
|
||||||
"directlinks": TaskTypeDirectlinks,
|
"directlinks": TaskTypeDirectlinks,
|
||||||
|
"aria2": TaskTypeAria2,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseTaskType attempts to convert a string to a TaskType.
|
// ParseTaskType attempts to convert a string to a TaskType.
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ type Add struct {
|
|||||||
ParsedItem *parser.Item
|
ParsedItem *parser.Item
|
||||||
// directlinks
|
// directlinks
|
||||||
DirectLinks []string
|
DirectLinks []string
|
||||||
|
// aria2
|
||||||
|
Aria2URIs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type SetDefaultStorage struct {
|
type SetDefaultStorage struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user