feat: implement task event system for progress tracking and reporting (#220)

This commit is contained in:
Krau
2026-06-25 21:37:36 +08:00
committed by GitHub
parent 9c2e70ed43
commit f02860ff3f
17 changed files with 397 additions and 214 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
taskID := task.TaskID()
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
// Inject the progress sink into the context so the task's Emit calls update
// the API store (and fire the webhook on terminal states) without the task
// knowing about the API.
taskCtx := taskevent.WithSink(f.ctx, info)
err := core.AddTask(taskCtx, task)
if err != nil {
DeleteTask(taskID)
return fmt.Errorf("failed to add task: %w", err)

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}
// 取消任务
// Cancel the task; the terminal status is set via the task event stream.
if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
return
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
return parts[3]
}
// convertTaskProgressToResponse 将任务进度转换为响应格式
// convertTaskProgressToResponse renders a task's current state, computing
// percent and speed from the snapshot taken under the task's mutex.
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
resp := TaskInfoResponse{
TaskID: task.TaskID,
Type: tasktype.TaskType(task.Type),
Status: task.Status,
Status: status,
Title: task.Title,
Storage: task.Storage,
Path: task.Path,
Error: task.Error,
Error: errMsg,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
UpdatedAt: updatedAt,
}
// 计算进度
if task.TotalBytes > 0 {
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
var percent float64
var speedMBPS float64
if total > 0 {
percent = float64(downloaded) * 100 / float64(total)
} else if totalFiles > 0 {
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
}
if !startedAt.IsZero() {
elapsed := time.Since(startedAt).Seconds()
if elapsed > 0 {
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
}
}
if total > 0 || totalFiles > 0 {
resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes,
DownloadedBytes: task.DownloadedBytes,
TotalBytes: total,
DownloadedBytes: downloaded,
TotalFiles: totalFiles,
DownloadedFiles: downloadedFiles,
Percent: percent,
SpeedMBPS: speedMBPS,
}
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
// setupTestServer creates a test server with handlers
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
tracker.OnStart(10000, 10)
info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
var wg sync.WaitGroup
numGoroutines := 50
updatesPerGoroutine := 100
// Concurrent progress updates
// Concurrent progress updates via the Sink interface
for i := range numGoroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := range updatesPerGoroutine {
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j)
info.Emit(taskevent.Event{
TaskID: "concurrent-progress",
Phase: taskevent.PhaseProgress,
DownloadedBytes: int64(id*updatesPerGoroutine + j),
TotalBytes: 10000,
})
}
}(i)
}
wg.Wait()
info := tracker.GetInfo()
if info.Status != TaskStatusRunning {
t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
status, _, downloaded, _, _, _, _, _ := info.snapshot()
if status != TaskStatusRunning {
t.Errorf("expected status Running after concurrent updates, got %s", status)
}
if downloaded <= 0 {
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
}
// Note: Due to race conditions in the simple implementation,
// we can't reliably check exact values without proper synchronization
}
// TestTaskFactoryValidation tests TaskFactory parameter validation
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
{
name: "Progress tracker with empty webhook",
fn: func(t *testing.T) {
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "")
info := tracker.GetInfo()
info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
if info.Webhook != "" {
t.Error("expected empty webhook")
}

View File

@@ -2,39 +2,48 @@ package api
import (
"sync"
"sync/atomic"
"time"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
)
// TaskProgressInfo 存储任务的进度信息
// TaskProgressInfo stores the progress of an API-submitted task. All fields are
// guarded by mu. It implements taskevent.Sink so the task layer can update it
// without knowing about the API.
type TaskProgressInfo struct {
TaskID string
Type string
Status TaskStatus
Title string
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
Error string
CreatedAt time.Time
UpdatedAt time.Time
Webhook string
mu sync.Mutex
TaskID string
Type string
Status TaskStatus
Title string
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
Error string
CreatedAt time.Time
UpdatedAt time.Time
StartedAt time.Time
Webhook string
webhookNotified bool
}
// progressStore 存储所有 API 任务的进度信息
// progressStore holds all API tasks. Entries are removed a fixed duration after
// they reach a terminal state to bound memory usage.
type progressStore struct {
mu sync.RWMutex
tasks map[string]*TaskProgressInfo
mu sync.RWMutex
tasks map[string]*TaskProgressInfo
retention time.Duration
}
var store = &progressStore{
tasks: make(map[string]*TaskProgressInfo),
tasks: make(map[string]*TaskProgressInfo),
retention: 24 * time.Hour,
}
// RegisterTask 注册一个新的 API 任务
// RegisterTask registers a new API task and returns its progress info.
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
info := &TaskProgressInfo{
TaskID: taskID,
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
return info
}
// GetTask 获取任务进度信息
// GetTask returns the progress info for a task.
func GetTask(taskID string) (*TaskProgressInfo, bool) {
store.mu.RLock()
defer store.mu.RUnlock()
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
return info, ok
}
// GetAllTasks 获取所有任务
// GetAllTasks returns all tracked tasks.
func GetAllTasks() []*TaskProgressInfo {
store.mu.RLock()
defer store.mu.RUnlock()
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
return tasks
}
// DeleteTask 删除任务记录
// DeleteTask removes a task record.
func DeleteTask(taskID string) {
store.mu.Lock()
defer store.mu.Unlock()
delete(store.tasks, taskID)
}
// UpdateStatus 更新任务状态
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.Status = status
t.UpdatedAt = time.Now()
// CleanupExpired removes tasks that reached a terminal state more than the
// store's retention duration ago. It is safe to call periodically.
func CleanupExpired() {
now := time.Now()
store.mu.Lock()
defer store.mu.Unlock()
for id, info := range store.tasks {
info.mu.Lock()
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
info.mu.Unlock()
if stale {
delete(store.tasks, id)
}
}
}
// SetError 设置错误信息
// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
// It should be started once during API server initialization.
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
go func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
CleanupExpired()
}
}
}()
}
// UpdateStatus sets the task status.
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.mu.Lock()
t.Status = status
t.UpdatedAt = time.Now()
if status == TaskStatusRunning && t.StartedAt.IsZero() {
t.StartedAt = t.UpdatedAt
}
t.mu.Unlock()
}
// SetError marks the task failed with an error message.
func (t *TaskProgressInfo) SetError(err string) {
t.mu.Lock()
t.Error = err
t.Status = TaskStatusFailed
t.UpdatedAt = time.Now()
t.mu.Unlock()
}
// ProgressTracker 用于 API 任务的进度追踪
type ProgressTracker struct {
info *TaskProgressInfo
// snapshot returns a point-in-time copy of the fields needed to render a
// response, so callers never touch the mutex directly.
func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
}
// NewProgressTracker 创建新的进度追踪器
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
return &ProgressTracker{info: info}
}
// OnStart 任务开始
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
p.info.Status = TaskStatusRunning
p.info.TotalBytes = totalBytes
p.info.TotalFiles = totalFiles
p.info.UpdatedAt = time.Now()
}
// OnProgress 进度更新
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
p.info.DownloadedFiles = downloadedFiles
p.info.UpdatedAt = time.Now()
}
// OnDone 任务完成
func (p *ProgressTracker) OnDone(err error) {
if err != nil {
p.info.Status = TaskStatusFailed
p.info.Error = err.Error()
} else {
p.info.Status = TaskStatusCompleted
// Emit implements taskevent.Sink. It translates task lifecycle events into
// status/progress updates and fires the webhook on terminal transitions.
func (t *TaskProgressInfo) Emit(e taskevent.Event) {
t.mu.Lock()
switch e.Phase {
case taskevent.PhaseStart:
t.Status = TaskStatusRunning
if t.StartedAt.IsZero() {
t.StartedAt = time.Now()
}
if e.TotalBytes > 0 {
t.TotalBytes = e.TotalBytes
}
case taskevent.PhaseProgress:
t.Status = TaskStatusRunning
if e.TotalBytes > 0 {
t.TotalBytes = e.TotalBytes
}
t.DownloadedBytes = e.DownloadedBytes
if e.TotalFiles > 0 {
t.TotalFiles = e.TotalFiles
}
if e.DownloadedFiles > 0 {
t.DownloadedFiles = e.DownloadedFiles
}
case taskevent.PhaseDone:
if e.Err != nil {
t.Status = TaskStatusFailed
t.Error = e.Err.Error()
} else {
t.Status = TaskStatusCompleted
}
}
t.UpdatedAt = time.Now()
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
if notify {
t.webhookNotified = true
}
t.mu.Unlock()
if notify {
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
SendWebhook(nil, payload)
}
p.info.UpdatedAt = time.Now()
}
// GetInfo 获取任务信息
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
return p.info
// ProgressTracker is retained for compatibility but is no longer the primary
// progress path; taskevent drives updates now. These methods are safe no-ops
// when called on a nil receiver.
type ProgressTracker struct{}
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
return &ProgressTracker{}
}
// UpdateProgressBytes 更新下载字节数
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
p.info.UpdatedAt = time.Now()
}
// UpdateProgressFiles 更新下载文件数
func (p *ProgressTracker) UpdateProgressFiles(files int) {
p.info.DownloadedFiles = files
p.info.UpdatedAt = time.Now()
}
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
func (p *ProgressTracker) OnDone(err error) {}
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
func (p *ProgressTracker) UpdateProgressFiles(files int) {}

View File

@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
// 404 处理
mux.HandleFunc("/", NotFoundHandler)
// 应用中间件
// Apply middleware chain.
var handler http.Handler = mux
// 添加认证中间件
// Apply auth middleware when a token is configured.
token := cfg.Token
if token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
}
if token != "" {
handler = AuthMiddleware()(handler)
}
// 添加日志中间件
// Add logging middleware.
handler = loggingMiddleware(handler)
// 添加恢复中间件
// Add recovery middleware.
handler = recoveryMiddleware(handler)
return &Server{
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
rw.ResponseWriter.WriteHeader(code)
}
// Start 初始化并启动 API 服务器
// Start initializes and starts the API server. It refuses to start without a
// token, since an open download proxy is a security risk.
func Start(ctx context.Context) error {
cfg := config.C().API
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
}
if cfg.Token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
}
server := NewServer(ctx)
return server.Start(ctx)
if err := server.Start(ctx); err != nil {
return err
}
StartCleanupLoop(ctx)
return nil
}

View File

@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
type TaskProgress struct {
TotalBytes int64 `json:"total_bytes,omitempty"`
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
TotalFiles int `json:"total_files,omitempty"`
DownloadedFiles int `json:"downloaded_files,omitempty"`
Percent float64 `json:"percent,omitempty"`
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
webhookURL := info.Webhook
// 异步发送 webhook
// Async send with retries.
go func() {
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
var logger *log.Logger
if ctx != nil {
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
} else {
logger = log.Default().With("task_id", payload.TaskID)
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
}()
}
// CreateWebhookPayload 创建 Webhook 负载
// CreateWebhookPayload creates a Webhook payload.
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
payload := &WebhookPayload{
TaskID: taskID,
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
return payload
}
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
info, ok := GetTask(taskID)
if !ok {
return fmt.Errorf("task not found: %s", taskID)
}
err := fn()
// 确定任务状态
status := TaskStatusCompleted
if err != nil {
if err == context.Canceled {
status = TaskStatusCancelled
} else {
status = TaskStatusFailed
}
}
// 更新任务状态
if err != nil {
info.SetError(err.Error())
} else {
info.UpdateStatus(TaskStatusCompleted)
}
// 发送 webhook
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}

View File

@@ -1,58 +0,0 @@
package api
import (
"context"
"errors"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
type ExecutableWrapper struct {
inner core.Executable
}
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
return &ExecutableWrapper{inner: inner}
}
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
taskID := w.inner.TaskID()
if info, ok := GetTask(taskID); ok {
info.UpdateStatus(TaskStatusRunning)
}
err := w.inner.Execute(ctx)
info, ok := GetTask(taskID)
if !ok {
return err
}
var status TaskStatus
if err != nil {
if errors.Is(err, context.Canceled) {
status = TaskStatusCancelled
info.UpdateStatus(TaskStatusCancelled)
} else {
status = TaskStatusFailed
info.SetError(err.Error())
}
} else {
status = TaskStatusCompleted
info.UpdateStatus(TaskStatusCompleted)
}
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}