mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-01 21:50:05 +08:00
feat: basic aria2 integration
This commit is contained in:
208
core/tasks/aria2dl/execute.go
Normal file
208
core/tasks/aria2dl/execute.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID)
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnStart(ctx, t)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var status *aria2.Status
|
||||
var err error
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Aria2 task canceled")
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, ctx.Err())
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Try to get status from active/waiting queue first
|
||||
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
|
||||
if err != nil {
|
||||
// If GID not found in active queue, check stopped queue
|
||||
logger.Debugf("Task not in active queue, checking stopped queue: %v", err)
|
||||
stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100)
|
||||
if stopErr != nil {
|
||||
logger.Errorf("Failed to get stopped tasks: %v", stopErr)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to get aria2 status: %w", err)
|
||||
}
|
||||
|
||||
// Find our task in stopped queue
|
||||
found := false
|
||||
for _, task := range stoppedTasks {
|
||||
if task.GID == t.GID {
|
||||
status = &task
|
||||
found = true
|
||||
logger.Debugf("Found task in stopped queue with status: %s", status.Status)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
logger.Errorf("Task GID %s not found in active or stopped queue", t.GID)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("aria2 task not found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debugf("Aria2 GID %s status: %s, completed: %s/%s",
|
||||
t.GID, status.Status, status.CompletedLength, status.TotalLength)
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, status)
|
||||
}
|
||||
|
||||
// Check if download is complete
|
||||
if status.IsDownloadComplete() {
|
||||
logger.Infof("Aria2 download completed for GID %s", t.GID)
|
||||
goto TransferFiles
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
if status.IsDownloadError() {
|
||||
err := fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode)
|
||||
logger.Errorf("Aria2 download failed: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if removed
|
||||
if status.IsDownloadRemoved() {
|
||||
err := errors.New("aria2 download was removed")
|
||||
logger.Error("Aria2 download was removed")
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TransferFiles:
|
||||
// Get final status to get file list
|
||||
status, err = t.Aria2Client.TellStatus(ctx, t.GID)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get final status: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to get final status: %w", err)
|
||||
}
|
||||
|
||||
if len(status.Files) == 0 {
|
||||
err := errors.New("no files in aria2 download")
|
||||
logger.Error("No files in aria2 download")
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Transfer files to storage
|
||||
logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name())
|
||||
for _, file := range status.Files {
|
||||
if file.Selected != "true" {
|
||||
logger.Debugf("Skipping unselected file: %s", file.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(file.Path); os.IsNotExist(err) {
|
||||
logger.Errorf("Downloaded file not found: %s", file.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
// Open file
|
||||
f, err := os.Open(file.Path)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to open file %s: %v", file.Path, err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to open file %s: %w", file.Path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Get file info
|
||||
fileInfo, err := f.Stat()
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to stat file %s: %v", file.Path, err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to stat file %s: %w", file.Path, err)
|
||||
}
|
||||
|
||||
// Set content length in context for storage
|
||||
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
|
||||
|
||||
// Determine destination path
|
||||
fileName := filepath.Base(file.Path)
|
||||
destPath := filepath.Join(t.StorPath, fileName)
|
||||
|
||||
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
|
||||
|
||||
// Save to storage
|
||||
err = t.Storage.Save(ctx, f, destPath)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to save file %s to storage: %v", fileName, err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
|
||||
}
|
||||
|
||||
logger.Infof("Successfully transferred file %s", fileName)
|
||||
|
||||
// Optionally remove the local file after successful transfer
|
||||
if config.C().Aria2.RemoveAfterTransfer {
|
||||
if err := os.Remove(file.Path); err != nil {
|
||||
logger.Warnf("Failed to remove local file %s: %v", file.Path, err)
|
||||
} else {
|
||||
logger.Debugf("Removed local file %s", file.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("Aria2 task %s completed successfully", t.ID)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, nil)
|
||||
}
|
||||
|
||||
// Clean up aria2 download result
|
||||
_, err = t.Aria2Client.RemoveDownloadResult(ctx, t.GID)
|
||||
if err != nil {
|
||||
logger.Warnf("Failed to remove aria2 download result: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
189
core/tasks/aria2dl/progress.go
Normal file
189
core/tasks/aria2dl/progress.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, task *Task)
|
||||
OnProgress(ctx context.Context, task *Task, status *aria2.Status)
|
||||
OnDone(ctx context.Context, task *Task, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
msgID int
|
||||
chatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
// OnStart implements ProgressTracker.
|
||||
func (p *Progress) OnStart(ctx context.Context, task *Task) {
|
||||
logger := log.FromContext(ctx)
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext == nil {
|
||||
return
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{
|
||||
"GID": task.GID,
|
||||
}))); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
|
||||
// OnProgress implements ProgressTracker.
|
||||
func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
|
||||
completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
|
||||
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
|
||||
|
||||
if totalLength == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
percent := int((completedLength * 100) / totalLength)
|
||||
if p.lastUpdatePercent.Load() == int32(percent) {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(int32(percent))
|
||||
|
||||
log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{
|
||||
"GID": task.GID,
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDone implements ProgressTracker.
|
||||
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
|
||||
logger := log.FromContext(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Aria2 task %s was canceled", task.TaskID())
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
|
||||
"TaskID": task.TaskID(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Aria2 task %s completed successfully", task.TaskID())
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{
|
||||
"GID": task.GID,
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||
); err != nil {
|
||||
logger.Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
var _ ProgressTracker = (*Progress)(nil)
|
||||
|
||||
func NewProgress(msgID int, userID int64) ProgressTracker {
|
||||
return &Progress{
|
||||
msgID: msgID,
|
||||
chatID: userID,
|
||||
}
|
||||
}
|
||||
61
core/tasks/aria2dl/task.go
Normal file
61
core/tasks/aria2dl/task.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
var _ core.Executable = (*Task)(nil)
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
ctx context.Context
|
||||
GID string
|
||||
URIs []string
|
||||
Aria2Client *aria2.Client
|
||||
Storage storage.Storage
|
||||
StorPath string
|
||||
Progress ProgressTracker
|
||||
}
|
||||
|
||||
// Title implements core.Executable.
|
||||
func (t *Task) Title() string {
|
||||
return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath)
|
||||
}
|
||||
|
||||
// Type implements core.Executable.
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeAria2
|
||||
}
|
||||
|
||||
// TaskID implements core.Executable.
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func NewTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
gid string,
|
||||
uris []string,
|
||||
aria2Client *aria2.Client,
|
||||
stor storage.Storage,
|
||||
storPath string,
|
||||
progressTracker ProgressTracker,
|
||||
) *Task {
|
||||
return &Task{
|
||||
ID: id,
|
||||
ctx: ctx,
|
||||
GID: gid,
|
||||
URIs: uris,
|
||||
Aria2Client: aria2Client,
|
||||
Storage: stor,
|
||||
StorPath: storPath,
|
||||
Progress: progressTracker,
|
||||
}
|
||||
}
|
||||
209
core/tasks/aria2dl/task_test.go
Normal file
209
core/tasks/aria2dl/task_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
storconfig "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
)
|
||||
|
||||
type mockStorage struct {
|
||||
name string
|
||||
savePath string
|
||||
}
|
||||
|
||||
func (m *mockStorage) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockStorage) Type() storenum.StorageType {
|
||||
return storenum.StorageType("mock")
|
||||
}
|
||||
|
||||
func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
|
||||
m.savePath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Exists(ctx context.Context, path string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockStorage) JoinStoragePath(path string) string {
|
||||
return path
|
||||
}
|
||||
|
||||
type mockProgress struct {
|
||||
started bool
|
||||
done bool
|
||||
doneErr error
|
||||
progress int
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnStart(ctx context.Context, task *Task) {
|
||||
m.started = true
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||
m.progress++
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) {
|
||||
m.done = true
|
||||
m.doneErr = err
|
||||
}
|
||||
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
if task.ID != "test-task-id" {
|
||||
t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID)
|
||||
}
|
||||
|
||||
if task.GID != "test-gid" {
|
||||
t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID)
|
||||
}
|
||||
|
||||
if task.Type() != tasktype.TaskTypeAria2 {
|
||||
t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type())
|
||||
}
|
||||
|
||||
if task.TaskID() != "test-task-id" {
|
||||
t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID())
|
||||
}
|
||||
|
||||
if task.Storage.Name() != "test-storage" {
|
||||
t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
// Test OnStart
|
||||
mockProg.OnStart(ctx, task)
|
||||
if !mockProg.started {
|
||||
t.Error("Expected OnStart to set started to true")
|
||||
}
|
||||
|
||||
// Test OnProgress
|
||||
status := &aria2.Status{
|
||||
GID: "test-gid",
|
||||
Status: "active",
|
||||
TotalLength: "1000000",
|
||||
CompletedLength: "500000",
|
||||
DownloadSpeed: "100000",
|
||||
}
|
||||
mockProg.OnProgress(ctx, task, status)
|
||||
if mockProg.progress != 1 {
|
||||
t.Errorf("Expected progress to be 1, got %d", mockProg.progress)
|
||||
}
|
||||
|
||||
// Test OnDone
|
||||
mockProg.OnDone(ctx, task, nil)
|
||||
if !mockProg.done {
|
||||
t.Error("Expected OnDone to set done to true")
|
||||
}
|
||||
if mockProg.doneErr != nil {
|
||||
t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskTitle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid-123",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
nil,
|
||||
)
|
||||
|
||||
title := task.Title()
|
||||
expectedSubstr := "test-gid-123"
|
||||
if len(title) == 0 {
|
||||
t.Error("Expected title to not be empty")
|
||||
}
|
||||
|
||||
// Check if title contains the GID
|
||||
found := false
|
||||
for i := 0; i < len(title)-len(expectedSubstr)+1; i++ {
|
||||
if title[i:i+len(expectedSubstr)] == expectedSubstr {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil, // nil client will cause Execute to fail/timeout
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
// Just verify the task structure is valid
|
||||
if task.ctx.Err() != nil {
|
||||
t.Error("Context should not be cancelled yet")
|
||||
}
|
||||
|
||||
// Wait for context to timeout
|
||||
<-ctx.Done()
|
||||
if ctx.Err() == nil {
|
||||
t.Error("Context should be cancelled after timeout")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user