From df3c568bb8bdbf3eb7dd381d1fdd261129fed6b8 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Thu, 25 Jun 2026 17:52:38 +0800 Subject: [PATCH] feat: implement task event system for progress tracking and reporting --- api/factory.go | 10 +- api/handlers.go | 39 ++++-- api/handlers_test.go | 28 ++-- api/progress.go | 216 +++++++++++++++++++----------- api/server.go | 22 +-- api/types.go | 2 + api/webhook.go | 47 ++----- api/wrapper.go | 58 -------- core/core.go | 9 +- core/tasks/aria2dl/execute.go | 21 +++ core/tasks/batchtfile/execute.go | 17 ++- core/tasks/directlinks/execute.go | 9 +- core/tasks/parsed/execute.go | 9 +- core/tasks/telegraph/execute.go | 9 +- core/tasks/tfile/writer.go | 20 ++- core/tasks/transfer/execute.go | 7 + pkg/taskevent/taskevent.go | 88 ++++++++++++ 17 files changed, 397 insertions(+), 214 deletions(-) delete mode 100644 api/wrapper.go create mode 100644 pkg/taskevent/taskevent.go diff --git a/api/factory.go b/api/factory.go index 4c00a6b..85bb2f5 100644 --- a/api/factory.go +++ b/api/factory.go @@ -20,6 +20,7 @@ import ( "github.com/krau/SaveAny-Bot/pkg/aria2" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/parser" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "github.com/krau/SaveAny-Bot/pkg/telegraph" "github.com/krau/SaveAny-Bot/storage" "github.com/rs/xid" @@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error { taskID := task.TaskID() - RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook) + info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook) - err := core.AddTask(f.ctx, NewExecutableWrapper(task)) + // Inject the progress sink into the context so the task's Emit calls update + // the API store (and fire the webhook on terminal states) without the task + // knowing about the API. + taskCtx := taskevent.WithSink(f.ctx, info) + + err := core.AddTask(taskCtx, task) if err != nil { DeleteTask(taskID) return fmt.Errorf("failed to add task: %w", err) diff --git a/api/handlers.go b/api/handlers.go index 81946d0..5f4f83b 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "strings" + "time" "github.com/krau/SaveAny-Bot/core" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" @@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // 取消任务 + // Cancel the task; the terminal status is set via the task event stream. if err := core.CancelTask(r.Context(), taskID); err != nil { WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error()) return @@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string { return parts[3] } -// convertTaskProgressToResponse 将任务进度转换为响应格式 +// convertTaskProgressToResponse renders a task's current state, computing +// percent and speed from the snapshot taken under the task's mutex. func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse { + status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot() + resp := TaskInfoResponse{ TaskID: task.TaskID, Type: tasktype.TaskType(task.Type), - Status: task.Status, + Status: status, Title: task.Title, Storage: task.Storage, Path: task.Path, - Error: task.Error, + Error: errMsg, CreatedAt: task.CreatedAt, - UpdatedAt: task.UpdatedAt, + UpdatedAt: updatedAt, } - // 计算进度 - if task.TotalBytes > 0 { - percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes) + var percent float64 + var speedMBPS float64 + if total > 0 { + percent = float64(downloaded) * 100 / float64(total) + } else if totalFiles > 0 { + percent = float64(downloadedFiles) * 100 / float64(totalFiles) + } + if !startedAt.IsZero() { + elapsed := time.Since(startedAt).Seconds() + if elapsed > 0 { + speedMBPS = float64(downloaded) / elapsed / (1024 * 1024) + } + } + + if total > 0 || totalFiles > 0 { resp.Progress = &TaskProgress{ - TotalBytes: task.TotalBytes, - DownloadedBytes: task.DownloadedBytes, + TotalBytes: total, + DownloadedBytes: downloaded, + TotalFiles: totalFiles, + DownloadedFiles: downloadedFiles, Percent: percent, + SpeedMBPS: speedMBPS, } } diff --git a/api/handlers_test.go b/api/handlers_test.go index 86452cd..f329d53 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/taskevent" ) // setupTestServer creates a test server with handlers @@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) { // TestProgressTrackerConcurrentUpdates tests concurrent progress updates func TestProgressTrackerConcurrentUpdates(t *testing.T) { - tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "") - tracker.OnStart(10000, 10) + info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "") + info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000}) var wg sync.WaitGroup numGoroutines := 50 updatesPerGoroutine := 100 - // Concurrent progress updates + // Concurrent progress updates via the Sink interface for i := range numGoroutines { wg.Add(1) go func(id int) { defer wg.Done() for j := range updatesPerGoroutine { - tracker.OnProgress(int64(id*updatesPerGoroutine+j), j) + info.Emit(taskevent.Event{ + TaskID: "concurrent-progress", + Phase: taskevent.PhaseProgress, + DownloadedBytes: int64(id*updatesPerGoroutine + j), + TotalBytes: 10000, + }) } }(i) } wg.Wait() - info := tracker.GetInfo() - if info.Status != TaskStatusRunning { - t.Errorf("expected status Running after concurrent updates, got %s", info.Status) + status, _, downloaded, _, _, _, _, _ := info.snapshot() + if status != TaskStatusRunning { + t.Errorf("expected status Running after concurrent updates, got %s", status) + } + if downloaded <= 0 { + t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded) } - // Note: Due to race conditions in the simple implementation, - // we can't reliably check exact values without proper synchronization } // TestTaskFactoryValidation tests TaskFactory parameter validation @@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) { { name: "Progress tracker with empty webhook", fn: func(t *testing.T) { - tracker := NewProgressTracker("test", "type", "storage", "path", "title", "") - info := tracker.GetInfo() + info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "") if info.Webhook != "" { t.Error("expected empty webhook") } diff --git a/api/progress.go b/api/progress.go index adac371..dcd3082 100644 --- a/api/progress.go +++ b/api/progress.go @@ -2,39 +2,48 @@ package api import ( "sync" - "sync/atomic" "time" + + "github.com/krau/SaveAny-Bot/pkg/taskevent" ) -// TaskProgressInfo 存储任务的进度信息 +// TaskProgressInfo stores the progress of an API-submitted task. All fields are +// guarded by mu. It implements taskevent.Sink so the task layer can update it +// without knowing about the API. type TaskProgressInfo struct { - TaskID string - Type string - Status TaskStatus - Title string - TotalBytes int64 - DownloadedBytes int64 - TotalFiles int - DownloadedFiles int - Storage string - Path string - Error string - CreatedAt time.Time - UpdatedAt time.Time - Webhook string + mu sync.Mutex + TaskID string + Type string + Status TaskStatus + Title string + TotalBytes int64 + DownloadedBytes int64 + TotalFiles int + DownloadedFiles int + Storage string + Path string + Error string + CreatedAt time.Time + UpdatedAt time.Time + StartedAt time.Time + Webhook string + webhookNotified bool } -// progressStore 存储所有 API 任务的进度信息 +// progressStore holds all API tasks. Entries are removed a fixed duration after +// they reach a terminal state to bound memory usage. type progressStore struct { - mu sync.RWMutex - tasks map[string]*TaskProgressInfo + mu sync.RWMutex + tasks map[string]*TaskProgressInfo + retention time.Duration } var store = &progressStore{ - tasks: make(map[string]*TaskProgressInfo), + tasks: make(map[string]*TaskProgressInfo), + retention: 24 * time.Hour, } -// RegisterTask 注册一个新的 API 任务 +// RegisterTask registers a new API task and returns its progress info. func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo { info := &TaskProgressInfo{ TaskID: taskID, @@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP return info } -// GetTask 获取任务进度信息 +// GetTask returns the progress info for a task. func GetTask(taskID string) (*TaskProgressInfo, bool) { store.mu.RLock() defer store.mu.RUnlock() @@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) { return info, ok } -// GetAllTasks 获取所有任务 +// GetAllTasks returns all tracked tasks. func GetAllTasks() []*TaskProgressInfo { store.mu.RLock() defer store.mu.RUnlock() @@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo { return tasks } -// DeleteTask 删除任务记录 +// DeleteTask removes a task record. func DeleteTask(taskID string) { store.mu.Lock() defer store.mu.Unlock() delete(store.tasks, taskID) } -// UpdateStatus 更新任务状态 -func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) { - t.Status = status - t.UpdatedAt = time.Now() +// CleanupExpired removes tasks that reached a terminal state more than the +// store's retention duration ago. It is safe to call periodically. +func CleanupExpired() { + now := time.Now() + store.mu.Lock() + defer store.mu.Unlock() + for id, info := range store.tasks { + info.mu.Lock() + terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled + stale := terminal && now.Sub(info.UpdatedAt) > store.retention + info.mu.Unlock() + if stale { + delete(store.tasks, id) + } + } } -// SetError 设置错误信息 +// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done. +// It should be started once during API server initialization. +func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) { + go func() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + CleanupExpired() + } + } + }() +} + +// UpdateStatus sets the task status. +func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) { + t.mu.Lock() + t.Status = status + t.UpdatedAt = time.Now() + if status == TaskStatusRunning && t.StartedAt.IsZero() { + t.StartedAt = t.UpdatedAt + } + t.mu.Unlock() +} + +// SetError marks the task failed with an error message. func (t *TaskProgressInfo) SetError(err string) { + t.mu.Lock() t.Error = err t.Status = TaskStatusFailed t.UpdatedAt = time.Now() + t.mu.Unlock() } -// ProgressTracker 用于 API 任务的进度追踪 -type ProgressTracker struct { - info *TaskProgressInfo +// snapshot returns a point-in-time copy of the fields needed to render a +// response, so callers never touch the mutex directly. +func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt } -// NewProgressTracker 创建新的进度追踪器 -func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker { - info := RegisterTask(taskID, taskType, storage, path, title, webhook) - return &ProgressTracker{info: info} -} - -// OnStart 任务开始 -func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) { - p.info.Status = TaskStatusRunning - p.info.TotalBytes = totalBytes - p.info.TotalFiles = totalFiles - p.info.UpdatedAt = time.Now() -} - -// OnProgress 进度更新 -func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) { - atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes) - p.info.DownloadedFiles = downloadedFiles - p.info.UpdatedAt = time.Now() -} - -// OnDone 任务完成 -func (p *ProgressTracker) OnDone(err error) { - if err != nil { - p.info.Status = TaskStatusFailed - p.info.Error = err.Error() - } else { - p.info.Status = TaskStatusCompleted +// Emit implements taskevent.Sink. It translates task lifecycle events into +// status/progress updates and fires the webhook on terminal transitions. +func (t *TaskProgressInfo) Emit(e taskevent.Event) { + t.mu.Lock() + switch e.Phase { + case taskevent.PhaseStart: + t.Status = TaskStatusRunning + if t.StartedAt.IsZero() { + t.StartedAt = time.Now() + } + if e.TotalBytes > 0 { + t.TotalBytes = e.TotalBytes + } + case taskevent.PhaseProgress: + t.Status = TaskStatusRunning + if e.TotalBytes > 0 { + t.TotalBytes = e.TotalBytes + } + t.DownloadedBytes = e.DownloadedBytes + if e.TotalFiles > 0 { + t.TotalFiles = e.TotalFiles + } + if e.DownloadedFiles > 0 { + t.DownloadedFiles = e.DownloadedFiles + } + case taskevent.PhaseDone: + if e.Err != nil { + t.Status = TaskStatusFailed + t.Error = e.Err.Error() + } else { + t.Status = TaskStatusCompleted + } + } + t.UpdatedAt = time.Now() + notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed) + if notify { + t.webhookNotified = true + } + t.mu.Unlock() + + if notify { + payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err) + SendWebhook(nil, payload) } - p.info.UpdatedAt = time.Now() } -// GetInfo 获取任务信息 -func (p *ProgressTracker) GetInfo() *TaskProgressInfo { - return p.info +// ProgressTracker is retained for compatibility but is no longer the primary +// progress path; taskevent drives updates now. These methods are safe no-ops +// when called on a nil receiver. +type ProgressTracker struct{} + +func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker { + return &ProgressTracker{} } -// UpdateProgressBytes 更新下载字节数 -func (p *ProgressTracker) UpdateProgressBytes(bytes int64) { - atomic.StoreInt64(&p.info.DownloadedBytes, bytes) - p.info.UpdatedAt = time.Now() -} - -// UpdateProgressFiles 更新下载文件数 -func (p *ProgressTracker) UpdateProgressFiles(files int) { - p.info.DownloadedFiles = files - p.info.UpdatedAt = time.Now() -} +func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {} +func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {} +func (p *ProgressTracker) OnDone(err error) {} +func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil } +func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {} +func (p *ProgressTracker) UpdateProgressFiles(files int) {} diff --git a/api/server.go b/api/server.go index be8b660..9418e1e 100644 --- a/api/server.go +++ b/api/server.go @@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server { // 404 处理 mux.HandleFunc("/", NotFoundHandler) - // 应用中间件 + // Apply middleware chain. var handler http.Handler = mux - // 添加认证中间件 + // Apply auth middleware when a token is configured. token := cfg.Token - if token == "" { - log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!") - } if token != "" { handler = AuthMiddleware()(handler) } - // 添加日志中间件 + // Add logging middleware. handler = loggingMiddleware(handler) - // 添加恢复中间件 + // Add recovery middleware. handler = recoveryMiddleware(handler) return &Server{ @@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) { rw.ResponseWriter.WriteHeader(code) } -// Start 初始化并启动 API 服务器 +// Start initializes and starts the API server. It refuses to start without a +// token, since an open download proxy is a security risk. func Start(ctx context.Context) error { cfg := config.C().API @@ -160,9 +158,13 @@ func Start(ctx context.Context) error { } if cfg.Token == "" { - log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!") + return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely") } server := NewServer(ctx) - return server.Start(ctx) + if err := server.Start(ctx); err != nil { + return err + } + StartCleanupLoop(ctx) + return nil } diff --git a/api/types.go b/api/types.go index 5462f6c..06d63c1 100644 --- a/api/types.go +++ b/api/types.go @@ -40,6 +40,8 @@ type CreateTaskResponse struct { type TaskProgress struct { TotalBytes int64 `json:"total_bytes,omitempty"` DownloadedBytes int64 `json:"downloaded_bytes,omitempty"` + TotalFiles int `json:"total_files,omitempty"` + DownloadedFiles int `json:"downloaded_files,omitempty"` Percent float64 `json:"percent,omitempty"` SpeedMBPS float64 `json:"speed_mbps,omitempty"` } diff --git a/api/webhook.go b/api/webhook.go index 16a50c4..7e4ad32 100644 --- a/api/webhook.go +++ b/api/webhook.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "fmt" "net/http" "time" @@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) { webhookURL := info.Webhook - // 异步发送 webhook + // Async send with retries. go func() { - logger := log.FromContext(ctx).With("task_id", payload.TaskID) + var logger *log.Logger + if ctx != nil { + logger = log.FromContext(ctx).With("task_id", payload.TaskID) + } else { + logger = log.Default().With("task_id", payload.TaskID) + } payloadBytes, err := json.Marshal(payload) if err != nil { @@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) { }() } -// CreateWebhookPayload 创建 Webhook 负载 +// CreateWebhookPayload creates a Webhook payload. func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload { payload := &WebhookPayload{ TaskID: taskID, @@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto return payload } - -// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调 -func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error { - info, ok := GetTask(taskID) - if !ok { - return fmt.Errorf("task not found: %s", taskID) - } - - err := fn() - - // 确定任务状态 - status := TaskStatusCompleted - if err != nil { - if err == context.Canceled { - status = TaskStatusCancelled - } else { - status = TaskStatusFailed - } - } - - // 更新任务状态 - if err != nil { - info.SetError(err.Error()) - } else { - info.UpdateStatus(TaskStatusCompleted) - } - - // 发送 webhook - if info.Webhook != "" { - payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err) - SendWebhook(ctx, payload) - } - - return err -} diff --git a/api/wrapper.go b/api/wrapper.go deleted file mode 100644 index a3bb18f..0000000 --- a/api/wrapper.go +++ /dev/null @@ -1,58 +0,0 @@ -package api - -import ( - "context" - "errors" - - "github.com/krau/SaveAny-Bot/core" - "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" -) - -// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks. -type ExecutableWrapper struct { - inner core.Executable -} - -func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper { - return &ExecutableWrapper{inner: inner} -} - -func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() } -func (w *ExecutableWrapper) Title() string { return w.inner.Title() } -func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() } - -func (w *ExecutableWrapper) Execute(ctx context.Context) error { - taskID := w.inner.TaskID() - - if info, ok := GetTask(taskID); ok { - info.UpdateStatus(TaskStatusRunning) - } - - err := w.inner.Execute(ctx) - - info, ok := GetTask(taskID) - if !ok { - return err - } - - var status TaskStatus - if err != nil { - if errors.Is(err, context.Canceled) { - status = TaskStatusCancelled - info.UpdateStatus(TaskStatusCancelled) - } else { - status = TaskStatusFailed - info.SetError(err.Error()) - } - } else { - status = TaskStatusCompleted - info.UpdateStatus(TaskStatusCompleted) - } - - if info.Webhook != "" { - payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err) - SendWebhook(ctx, payload) - } - - return err -} diff --git a/core/core.go b/core/core.go index fe123c2..0bda7c9 100644 --- a/core/core.go +++ b/core/core.go @@ -8,6 +8,7 @@ import ( "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/queue" + "github.com/krau/SaveAny-Bot/pkg/taskevent" ) var queueInstance *queue.TaskQueue[Executable] @@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan break // queue closed and empty } exe := qtask.Data + taskCtx := qtask.Context() logger.Infof("Processing task: %s", exe.TaskID()) - if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil { + taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart}) + if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil { logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err) } - if err := exe.Execute(qtask.Context()); err != nil { + err = exe.Execute(taskCtx) + if err != nil { if errors.Is(err, context.Canceled) { logger.Infof("Task %s was canceled", exe.TaskID()) if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil { @@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err) } } + taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err}) qe.Done(qtask.ID) <-semaphore } diff --git a/core/tasks/aria2dl/execute.go b/core/tasks/aria2dl/execute.go index 43e53de..0ba47b6 100644 --- a/core/tasks/aria2dl/execute.go +++ b/core/tasks/aria2dl/execute.go @@ -6,12 +6,14 @@ import ( "fmt" "os" "path/filepath" + "strconv" "time" "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/aria2" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "github.com/krau/SaveAny-Bot/pkg/taskevent" ) // Execute implements core.Executable. @@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error { if t.Progress != nil { t.Progress.OnProgress(ctx, t, status) } + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: parseInt64(status.TotalLength), + DownloadedBytes: parseInt64(status.CompletedLength), + }) // Check if download is complete if status.IsDownloadComplete() { @@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() { logger.Debugf("Failed to remove download result for %s: %v", t.GID, err) } } + +// parseInt64 parses an aria2 status string (decimal bytes) into int64, +// returning 0 on failure so it can be used directly in progress events. +func parseInt64(s string) int64 { + if s == "" { + return 0 + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0 + } + return n +} diff --git a/core/tasks/batchtfile/execute.go b/core/tasks/batchtfile/execute.go index 403d250..c686919 100644 --- a/core/tasks/batchtfile/execute.go +++ b/core/tasks/batchtfile/execute.go @@ -14,6 +14,7 @@ import ( "github.com/krau/SaveAny-Bot/common/utils/ioutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "golang.org/x/sync/errgroup" ) @@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return elem.Storage.Save(uploadCtx, pr, elem.Path) }) wr := ioutil.NewProgressWriter(pw, func(n int) { - t.downloaded.Add(int64(n)) + downloaded := t.downloaded.Add(int64(n)) t.Progress.OnProgress(ctx, t) + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: t.totalSize, + DownloadedBytes: downloaded, + }) }) errg.Go(func() error { defer pw.Close() @@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { } }() wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) { - t.downloaded.Add(int64(n)) + downloaded := t.downloaded.Add(int64(n)) t.Progress.OnProgress(ctx, t) + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: t.totalSize, + DownloadedBytes: downloaded, + }) }) _, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt) if err != nil { diff --git a/core/tasks/directlinks/execute.go b/core/tasks/directlinks/execute.go index c5f10d5..4ea50cc 100644 --- a/core/tasks/directlinks/execute.go +++ b/core/tasks/directlinks/execute.go @@ -15,6 +15,7 @@ import ( "github.com/krau/SaveAny-Bot/common/utils/ioutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "golang.org/x/sync/errgroup" ) @@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error { } }() wr := ioutil.NewProgressWriter(cacheFile, func(n int) { - t.downloadedBytes.Add(int64(n)) + downloaded := t.downloadedBytes.Add(int64(n)) if t.Progress != nil { t.Progress.OnProgress(ctx, t) } + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: t.totalBytes, + DownloadedBytes: downloaded, + }) }) copyResultCh := make(chan error, 1) diff --git a/core/tasks/parsed/execute.go b/core/tasks/parsed/execute.go index f97e6d9..9075e81 100644 --- a/core/tasks/parsed/execute.go +++ b/core/tasks/parsed/execute.go @@ -16,6 +16,7 @@ import ( "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/parser" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "golang.org/x/sync/errgroup" ) @@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er } }() wr := ioutil.NewProgressWriter(cacheFile, func(n int) { - t.downloadedBytes.Add(int64(n)) + downloaded := t.downloadedBytes.Add(int64(n)) if t.progress != nil { t.progress.OnProgress(ctx, t) } + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: t.totalBytes, + DownloadedBytes: downloaded, + }) }) copyResultCh := make(chan error, 1) diff --git a/core/tasks/telegraph/execute.go b/core/tasks/telegraph/execute.go index dfa327b..0d67803 100644 --- a/core/tasks/telegraph/execute.go +++ b/core/tasks/telegraph/execute.go @@ -11,6 +11,7 @@ import ( "github.com/duke-git/lancet/v2/retry" "github.com/krau/SaveAny-Bot/common/utils/fsutil" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "golang.org/x/sync/errgroup" ) @@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error { logger.Errorf("Error processing picture %s: %v", pic, err) return fmt.Errorf("failed to process picture %s: %w", pic, err) } - t.downloaded.Add(1) + downloaded := t.downloaded.Add(1) t.progress.OnProgress(gctx, t) + taskevent.Emit(gctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalFiles: t.totalpics, + DownloadedFiles: int(downloaded), + }) return nil }) } diff --git a/core/tasks/tfile/writer.go b/core/tasks/tfile/writer.go index 7d76ee5..09fc1c5 100644 --- a/core/tasks/tfile/writer.go +++ b/core/tasks/tfile/writer.go @@ -4,6 +4,8 @@ import ( "context" "io" "sync/atomic" + + "github.com/krau/SaveAny-Bot/pkg/taskevent" ) type ProgressWriterAt struct { @@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) { if err != nil { return 0, err } + downloaded := w.downloaded.Add(int64(at)) if w.progress != nil { - w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) + w.progress.OnProgress(w.ctx, w.info, downloaded, w.total) } + taskevent.Emit(w.ctx, taskevent.Event{ + TaskID: w.info.TaskID(), + Phase: taskevent.PhaseProgress, + TotalBytes: w.total, + DownloadedBytes: downloaded, + }) return at, nil } @@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) { if err != nil { return 0, err } + downloaded := w.downloaded.Add(int64(at)) if w.progress != nil { - w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) + w.progress.OnProgress(w.ctx, w.info, downloaded, w.total) } + taskevent.Emit(w.ctx, taskevent.Event{ + TaskID: w.info.TaskID(), + Phase: taskevent.PhaseProgress, + TotalBytes: w.total, + DownloadedBytes: downloaded, + }) return at, nil } diff --git a/core/tasks/transfer/execute.go b/core/tasks/transfer/execute.go index dc57da6..13cc524 100644 --- a/core/tasks/transfer/execute.go +++ b/core/tasks/transfer/execute.go @@ -11,6 +11,7 @@ import ( "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "github.com/krau/SaveAny-Bot/pkg/taskevent" "github.com/krau/SaveAny-Bot/storage" "golang.org/x/sync/errgroup" ) @@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { t.uploaded.Add(size) t.Progress.OnProgress(ctx, t) + taskevent.Emit(ctx, taskevent.Event{ + TaskID: t.ID, + Phase: taskevent.PhaseProgress, + TotalBytes: t.totalSize, + DownloadedBytes: t.uploaded.Load(), + }) logger.Info("File uploaded successfully") return nil diff --git a/pkg/taskevent/taskevent.go b/pkg/taskevent/taskevent.go new file mode 100644 index 0000000..cd2d32c --- /dev/null +++ b/pkg/taskevent/taskevent.go @@ -0,0 +1,88 @@ +// Package taskevent provides a decoupled, context-scoped event bus for task +// lifecycle progress. Producers (task implementations) emit events via Emit; +// consumers (e.g. the API progress store, the Telegram message editor) register +// as Sinks and are injected through context. This keeps the task layer free of +// any concrete progress-display dependency, so new task types gain progress +// reporting for free and new observers can be added without touching tasks. +package taskevent + +import "context" + +// Phase marks a stage in a task's lifecycle. +type Phase int + +const ( + PhaseStart Phase = iota + PhaseProgress + PhaseDone +) + +func (p Phase) String() string { + switch p { + case PhaseStart: + return "start" + case PhaseProgress: + return "progress" + case PhaseDone: + return "done" + default: + return "unknown" + } +} + +// Event describes a single progress observation for a task. Byte fields are +// populated by byte-stream tasks; file-count fields by count-based tasks. A +// task may fill whichever subset it has; observers ignore zero values. +type Event struct { + TaskID string + Phase Phase + TotalBytes int64 + DownloadedBytes int64 + TotalFiles int + DownloadedFiles int + Err error +} + +// Sink receives task events. Implementations must be safe for concurrent use. +type Sink interface { + Emit(Event) +} + +// SinkFunc is a function adapter for Sink. +type SinkFunc func(Event) + +func (f SinkFunc) Emit(e Event) { f(e) } + +type sinkKey struct{} + +// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed +// and all will receive every emitted event. Sinks already present in ctx are +// preserved. +func WithSink(ctx context.Context, sinks ...Sink) context.Context { + if len(sinks) == 0 { + return ctx + } + var existing []Sink + if v, ok := ctx.Value(sinkKey{}).([]Sink); ok { + existing = v + } + merged := make([]Sink, 0, len(existing)+len(sinks)) + merged = append(merged, existing...) + merged = append(merged, sinks...) + return context.WithValue(ctx, sinkKey{}, merged) +} + +// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no +// sink is attached, so producers can call it unconditionally. +func Emit(ctx context.Context, e Event) { + if ctx == nil { + return + } + sinks, ok := ctx.Value(sinkKey{}).([]Sink) + if !ok { + return + } + for _, s := range sinks { + s.Emit(e) + } +}