feat: implement task event system for progress tracking and reporting

This commit is contained in:
krau
2026-06-25 17:52:38 +08:00
parent 9c2e70ed43
commit df3c568bb8
17 changed files with 397 additions and 214 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
taskID := task.TaskID()
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
// Inject the progress sink into the context so the task's Emit calls update
// the API store (and fire the webhook on terminal states) without the task
// knowing about the API.
taskCtx := taskevent.WithSink(f.ctx, info)
err := core.AddTask(taskCtx, task)
if err != nil {
DeleteTask(taskID)
return fmt.Errorf("failed to add task: %w", err)

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}
// 取消任务
// Cancel the task; the terminal status is set via the task event stream.
if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
return
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
return parts[3]
}
// convertTaskProgressToResponse 将任务进度转换为响应格式
// convertTaskProgressToResponse renders a task's current state, computing
// percent and speed from the snapshot taken under the task's mutex.
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
resp := TaskInfoResponse{
TaskID: task.TaskID,
Type: tasktype.TaskType(task.Type),
Status: task.Status,
Status: status,
Title: task.Title,
Storage: task.Storage,
Path: task.Path,
Error: task.Error,
Error: errMsg,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
UpdatedAt: updatedAt,
}
// 计算进度
if task.TotalBytes > 0 {
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
var percent float64
var speedMBPS float64
if total > 0 {
percent = float64(downloaded) * 100 / float64(total)
} else if totalFiles > 0 {
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
}
if !startedAt.IsZero() {
elapsed := time.Since(startedAt).Seconds()
if elapsed > 0 {
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
}
}
if total > 0 || totalFiles > 0 {
resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes,
DownloadedBytes: task.DownloadedBytes,
TotalBytes: total,
DownloadedBytes: downloaded,
TotalFiles: totalFiles,
DownloadedFiles: downloadedFiles,
Percent: percent,
SpeedMBPS: speedMBPS,
}
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
// setupTestServer creates a test server with handlers
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
tracker.OnStart(10000, 10)
info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
var wg sync.WaitGroup
numGoroutines := 50
updatesPerGoroutine := 100
// Concurrent progress updates
// Concurrent progress updates via the Sink interface
for i := range numGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := range updatesPerGoroutine {
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j)
info.Emit(taskevent.Event{
TaskID: "concurrent-progress",
Phase: taskevent.PhaseProgress,
DownloadedBytes: int64(id*updatesPerGoroutine + j),
TotalBytes: 10000,
})
}
}(i)
}
wg.Wait()
info := tracker.GetInfo()
if info.Status != TaskStatusRunning {
t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
status, _, downloaded, _, _, _, _, _ := info.snapshot()
if status != TaskStatusRunning {
t.Errorf("expected status Running after concurrent updates, got %s", status)
}
if downloaded <= 0 {
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
}
// Note: Due to race conditions in the simple implementation,
// we can't reliably check exact values without proper synchronization
}
// TestTaskFactoryValidation tests TaskFactory parameter validation
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
{
name: "Progress tracker with empty webhook",
fn: func(t *testing.T) {
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "")
info := tracker.GetInfo()
info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
if info.Webhook != "" {
t.Error("expected empty webhook")
}

View File

@@ -2,39 +2,48 @@ package api
import (
"sync"
"sync/atomic"
"time"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
// TaskProgressInfo 存储任务的进度信息
// TaskProgressInfo stores the progress of an API-submitted task. All fields are
// guarded by mu. It implements taskevent.Sink so the task layer can update it
// without knowing about the API.
type TaskProgressInfo struct {
TaskID string
Type string
Status TaskStatus
Title string
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
Error string
CreatedAt time.Time
UpdatedAt time.Time
Webhook string
mu sync.Mutex
TaskID string
Type string
Status TaskStatus
Title string
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
Error string
CreatedAt time.Time
UpdatedAt time.Time
StartedAt time.Time
Webhook string
webhookNotified bool
}
// progressStore 存储所有 API 任务的进度信息
// progressStore holds all API tasks. Entries are removed a fixed duration after
// they reach a terminal state to bound memory usage.
type progressStore struct {
mu sync.RWMutex
tasks map[string]*TaskProgressInfo
mu sync.RWMutex
tasks map[string]*TaskProgressInfo
retention time.Duration
}
var store = &progressStore{
tasks: make(map[string]*TaskProgressInfo),
tasks: make(map[string]*TaskProgressInfo),
retention: 24 * time.Hour,
}
// RegisterTask 注册一个新的 API 任务
// RegisterTask registers a new API task and returns its progress info.
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
info := &TaskProgressInfo{
TaskID: taskID,
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
return info
}
// GetTask 获取任务进度信息
// GetTask returns the progress info for a task.
func GetTask(taskID string) (*TaskProgressInfo, bool) {
store.mu.RLock()
defer store.mu.RUnlock()
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
return info, ok
}
// GetAllTasks 获取所有任务
// GetAllTasks returns all tracked tasks.
func GetAllTasks() []*TaskProgressInfo {
store.mu.RLock()
defer store.mu.RUnlock()
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
return tasks
}
// DeleteTask 删除任务记录
// DeleteTask removes a task record.
func DeleteTask(taskID string) {
store.mu.Lock()
defer store.mu.Unlock()
delete(store.tasks, taskID)
}
// UpdateStatus 更新任务状态
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.Status = status
t.UpdatedAt = time.Now()
// CleanupExpired removes tasks that reached a terminal state more than the
// store's retention duration ago. It is safe to call periodically.
func CleanupExpired() {
now := time.Now()
store.mu.Lock()
defer store.mu.Unlock()
for id, info := range store.tasks {
info.mu.Lock()
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
info.mu.Unlock()
if stale {
delete(store.tasks, id)
}
}
}
// SetError 设置错误信息
// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
// It should be started once during API server initialization.
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
go func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
CleanupExpired()
}
}
}()
}
// UpdateStatus sets the task status.
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.mu.Lock()
t.Status = status
t.UpdatedAt = time.Now()
if status == TaskStatusRunning && t.StartedAt.IsZero() {
t.StartedAt = t.UpdatedAt
}
t.mu.Unlock()
}
// SetError marks the task failed with an error message.
func (t *TaskProgressInfo) SetError(err string) {
t.mu.Lock()
t.Error = err
t.Status = TaskStatusFailed
t.UpdatedAt = time.Now()
t.mu.Unlock()
}
// ProgressTracker 用于 API 任务的进度追踪
type ProgressTracker struct {
info *TaskProgressInfo
// snapshot returns a point-in-time copy of the fields needed to render a
// response, so callers never touch the mutex directly.
func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
}
// NewProgressTracker 创建新的进度追踪器
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
return &ProgressTracker{info: info}
}
// OnStart 任务开始
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
p.info.Status = TaskStatusRunning
p.info.TotalBytes = totalBytes
p.info.TotalFiles = totalFiles
p.info.UpdatedAt = time.Now()
}
// OnProgress 进度更新
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
p.info.DownloadedFiles = downloadedFiles
p.info.UpdatedAt = time.Now()
}
// OnDone 任务完成
func (p *ProgressTracker) OnDone(err error) {
if err != nil {
p.info.Status = TaskStatusFailed
p.info.Error = err.Error()
} else {
p.info.Status = TaskStatusCompleted
// Emit implements taskevent.Sink. It translates task lifecycle events into
// status/progress updates and fires the webhook on terminal transitions.
func (t *TaskProgressInfo) Emit(e taskevent.Event) {
t.mu.Lock()
switch e.Phase {
case taskevent.PhaseStart:
t.Status = TaskStatusRunning
if t.StartedAt.IsZero() {
t.StartedAt = time.Now()
}
if e.TotalBytes > 0 {
t.TotalBytes = e.TotalBytes
}
case taskevent.PhaseProgress:
t.Status = TaskStatusRunning
if e.TotalBytes > 0 {
t.TotalBytes = e.TotalBytes
}
t.DownloadedBytes = e.DownloadedBytes
if e.TotalFiles > 0 {
t.TotalFiles = e.TotalFiles
}
if e.DownloadedFiles > 0 {
t.DownloadedFiles = e.DownloadedFiles
}
case taskevent.PhaseDone:
if e.Err != nil {
t.Status = TaskStatusFailed
t.Error = e.Err.Error()
} else {
t.Status = TaskStatusCompleted
}
}
t.UpdatedAt = time.Now()
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
if notify {
t.webhookNotified = true
}
t.mu.Unlock()
if notify {
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
SendWebhook(nil, payload)
}
p.info.UpdatedAt = time.Now()
}
// GetInfo 获取任务信息
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
return p.info
// ProgressTracker is retained for compatibility but is no longer the primary
// progress path; taskevent drives updates now. These methods are safe no-ops
// when called on a nil receiver.
type ProgressTracker struct{}
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
return &ProgressTracker{}
}
// UpdateProgressBytes 更新下载字节数
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
p.info.UpdatedAt = time.Now()
}
// UpdateProgressFiles 更新下载文件数
func (p *ProgressTracker) UpdateProgressFiles(files int) {
p.info.DownloadedFiles = files
p.info.UpdatedAt = time.Now()
}
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
func (p *ProgressTracker) OnDone(err error) {}
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
func (p *ProgressTracker) UpdateProgressFiles(files int) {}

View File

@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
// 404 处理
mux.HandleFunc("/", NotFoundHandler)
// 应用中间件
// Apply middleware chain.
var handler http.Handler = mux
// 添加认证中间件
// Apply auth middleware when a token is configured.
token := cfg.Token
if token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
}
if token != "" {
handler = AuthMiddleware()(handler)
}
// 添加日志中间件
// Add logging middleware.
handler = loggingMiddleware(handler)
// 添加恢复中间件
// Add recovery middleware.
handler = recoveryMiddleware(handler)
return &Server{
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
rw.ResponseWriter.WriteHeader(code)
}
// Start 初始化并启动 API 服务器
// Start initializes and starts the API server. It refuses to start without a
// token, since an open download proxy is a security risk.
func Start(ctx context.Context) error {
cfg := config.C().API
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
}
if cfg.Token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
}
server := NewServer(ctx)
return server.Start(ctx)
if err := server.Start(ctx); err != nil {
return err
}
StartCleanupLoop(ctx)
return nil
}

View File

@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
type TaskProgress struct {
TotalBytes int64 `json:"total_bytes,omitempty"`
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
TotalFiles int `json:"total_files,omitempty"`
DownloadedFiles int `json:"downloaded_files,omitempty"`
Percent float64 `json:"percent,omitempty"`
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
webhookURL := info.Webhook
// 异步发送 webhook
// Async send with retries.
go func() {
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
var logger *log.Logger
if ctx != nil {
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
} else {
logger = log.Default().With("task_id", payload.TaskID)
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
}()
}
// CreateWebhookPayload 创建 Webhook 负载
// CreateWebhookPayload creates a Webhook payload.
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
payload := &WebhookPayload{
TaskID: taskID,
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
return payload
}
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
info, ok := GetTask(taskID)
if !ok {
return fmt.Errorf("task not found: %s", taskID)
}
err := fn()
// 确定任务状态
status := TaskStatusCompleted
if err != nil {
if err == context.Canceled {
status = TaskStatusCancelled
} else {
status = TaskStatusFailed
}
}
// 更新任务状态
if err != nil {
info.SetError(err.Error())
} else {
info.UpdateStatus(TaskStatusCompleted)
}
// 发送 webhook
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}

View File

@@ -1,58 +0,0 @@
package api
import (
"context"
"errors"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
type ExecutableWrapper struct {
inner core.Executable
}
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
return &ExecutableWrapper{inner: inner}
}
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
taskID := w.inner.TaskID()
if info, ok := GetTask(taskID); ok {
info.UpdateStatus(TaskStatusRunning)
}
err := w.inner.Execute(ctx)
info, ok := GetTask(taskID)
if !ok {
return err
}
var status TaskStatus
if err != nil {
if errors.Is(err, context.Canceled) {
status = TaskStatusCancelled
info.UpdateStatus(TaskStatusCancelled)
} else {
status = TaskStatusFailed
info.SetError(err.Error())
}
} else {
status = TaskStatusCompleted
info.UpdateStatus(TaskStatusCompleted)
}
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/queue"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
var queueInstance *queue.TaskQueue[Executable]
@@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
break // queue closed and empty
}
exe := qtask.Data
taskCtx := qtask.Context()
logger.Infof("Processing task: %s", exe.TaskID())
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart})
if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil {
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
}
if err := exe.Execute(qtask.Context()); err != nil {
err = exe.Execute(taskCtx)
if err != nil {
if errors.Is(err, context.Canceled) {
logger.Infof("Task %s was canceled", exe.TaskID())
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
@@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
}
}
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err})
qe.Done(qtask.ID)
<-semaphore
}

View File

@@ -6,12 +6,14 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
// Execute implements core.Executable.
@@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error {
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, status)
}
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: parseInt64(status.TotalLength),
DownloadedBytes: parseInt64(status.CompletedLength),
})
// Check if download is complete
if status.IsDownloadComplete() {
@@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() {
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
}
}
// parseInt64 parses an aria2 status string (decimal bytes) into int64,
// returning 0 on failure so it can be used directly in progress events.
func parseInt64(s string) int64 {
if s == "" {
return 0
}
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0
}
return n
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup"
)
@@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
return elem.Storage.Save(uploadCtx, pr, elem.Path)
})
wr := ioutil.NewProgressWriter(pw, func(n int) {
t.downloaded.Add(int64(n))
downloaded := t.downloaded.Add(int64(n))
t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: downloaded,
})
})
errg.Go(func() error {
defer pw.Close()
@@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
}
}()
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
t.downloaded.Add(int64(n))
downloaded := t.downloaded.Add(int64(n))
t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: downloaded,
})
})
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup"
)
@@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
}
}()
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
t.downloadedBytes.Add(int64(n))
downloaded := t.downloadedBytes.Add(int64(n))
if t.Progress != nil {
t.Progress.OnProgress(ctx, t)
}
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalBytes,
DownloadedBytes: downloaded,
})
})
copyResultCh := make(chan error, 1)

View File

@@ -16,6 +16,7 @@ import (
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup"
)
@@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
}
}()
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
t.downloadedBytes.Add(int64(n))
downloaded := t.downloadedBytes.Add(int64(n))
if t.progress != nil {
t.progress.OnProgress(ctx, t)
}
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalBytes,
DownloadedBytes: downloaded,
})
})
copyResultCh := make(chan error, 1)

View File

@@ -11,6 +11,7 @@ import (
"github.com/duke-git/lancet/v2/retry"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup"
)
@@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error {
logger.Errorf("Error processing picture %s: %v", pic, err)
return fmt.Errorf("failed to process picture %s: %w", pic, err)
}
t.downloaded.Add(1)
downloaded := t.downloaded.Add(1)
t.progress.OnProgress(gctx, t)
taskevent.Emit(gctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalFiles: t.totalpics,
DownloadedFiles: int(downloaded),
})
return nil
})
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"io"
"sync/atomic"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
type ProgressWriterAt struct {
@@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
if err != nil {
return 0, err
}
downloaded := w.downloaded.Add(int64(at))
if w.progress != nil {
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
}
taskevent.Emit(w.ctx, taskevent.Event{
TaskID: w.info.TaskID(),
Phase: taskevent.PhaseProgress,
TotalBytes: w.total,
DownloadedBytes: downloaded,
})
return at, nil
}
@@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) {
if err != nil {
return 0, err
}
downloaded := w.downloaded.Add(int64(at))
if w.progress != nil {
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
}
taskevent.Emit(w.ctx, taskevent.Event{
TaskID: w.info.TaskID(),
Phase: taskevent.PhaseProgress,
TotalBytes: w.total,
DownloadedBytes: downloaded,
})
return at, nil
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"github.com/krau/SaveAny-Bot/storage"
"golang.org/x/sync/errgroup"
)
@@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
t.uploaded.Add(size)
t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: t.uploaded.Load(),
})
logger.Info("File uploaded successfully")
return nil

View File

@@ -0,0 +1,88 @@
// Package taskevent provides a decoupled, context-scoped event bus for task
// lifecycle progress. Producers (task implementations) emit events via Emit;
// consumers (e.g. the API progress store, the Telegram message editor) register
// as Sinks and are injected through context. This keeps the task layer free of
// any concrete progress-display dependency, so new task types gain progress
// reporting for free and new observers can be added without touching tasks.
package taskevent
import "context"
// Phase marks a stage in a task's lifecycle.
type Phase int
const (
PhaseStart Phase = iota
PhaseProgress
PhaseDone
)
func (p Phase) String() string {
switch p {
case PhaseStart:
return "start"
case PhaseProgress:
return "progress"
case PhaseDone:
return "done"
default:
return "unknown"
}
}
// Event describes a single progress observation for a task. Byte fields are
// populated by byte-stream tasks; file-count fields by count-based tasks. A
// task may fill whichever subset it has; observers ignore zero values.
type Event struct {
TaskID string
Phase Phase
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Err error
}
// Sink receives task events. Implementations must be safe for concurrent use.
type Sink interface {
Emit(Event)
}
// SinkFunc is a function adapter for Sink.
type SinkFunc func(Event)
func (f SinkFunc) Emit(e Event) { f(e) }
type sinkKey struct{}
// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed
// and all will receive every emitted event. Sinks already present in ctx are
// preserved.
func WithSink(ctx context.Context, sinks ...Sink) context.Context {
if len(sinks) == 0 {
return ctx
}
var existing []Sink
if v, ok := ctx.Value(sinkKey{}).([]Sink); ok {
existing = v
}
merged := make([]Sink, 0, len(existing)+len(sinks))
merged = append(merged, existing...)
merged = append(merged, sinks...)
return context.WithValue(ctx, sinkKey{}, merged)
}
// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no
// sink is attached, so producers can call it unconditionally.
func Emit(ctx context.Context, e Event) {
if ctx == nil {
return
}
sinks, ok := ctx.Value(sinkKey{}).([]Sink)
if !ok {
return
}
for _, s := range sinks {
s.Emit(e)
}
}