mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-26 17:51:32 +08:00
feat: implement task event system for progress tracking and reporting (#220)
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
|
||||
|
||||
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
|
||||
taskID := task.TaskID()
|
||||
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||
info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||
|
||||
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
|
||||
// Inject the progress sink into the context so the task's Emit calls update
|
||||
// the API store (and fire the webhook on terminal states) without the task
|
||||
// knowing about the API.
|
||||
taskCtx := taskevent.WithSink(f.ctx, info)
|
||||
|
||||
err := core.AddTask(taskCtx, task)
|
||||
if err != nil {
|
||||
DeleteTask(taskID)
|
||||
return fmt.Errorf("failed to add task: %w", err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 取消任务
|
||||
// Cancel the task; the terminal status is set via the task event stream.
|
||||
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
||||
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
|
||||
return
|
||||
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
|
||||
return parts[3]
|
||||
}
|
||||
|
||||
// convertTaskProgressToResponse 将任务进度转换为响应格式
|
||||
// convertTaskProgressToResponse renders a task's current state, computing
|
||||
// percent and speed from the snapshot taken under the task's mutex.
|
||||
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
|
||||
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
|
||||
|
||||
resp := TaskInfoResponse{
|
||||
TaskID: task.TaskID,
|
||||
Type: tasktype.TaskType(task.Type),
|
||||
Status: task.Status,
|
||||
Status: status,
|
||||
Title: task.Title,
|
||||
Storage: task.Storage,
|
||||
Path: task.Path,
|
||||
Error: task.Error,
|
||||
Error: errMsg,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
// 计算进度
|
||||
if task.TotalBytes > 0 {
|
||||
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
|
||||
var percent float64
|
||||
var speedMBPS float64
|
||||
if total > 0 {
|
||||
percent = float64(downloaded) * 100 / float64(total)
|
||||
} else if totalFiles > 0 {
|
||||
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
|
||||
}
|
||||
if !startedAt.IsZero() {
|
||||
elapsed := time.Since(startedAt).Seconds()
|
||||
if elapsed > 0 {
|
||||
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
|
||||
}
|
||||
}
|
||||
|
||||
if total > 0 || totalFiles > 0 {
|
||||
resp.Progress = &TaskProgress{
|
||||
TotalBytes: task.TotalBytes,
|
||||
DownloadedBytes: task.DownloadedBytes,
|
||||
TotalBytes: total,
|
||||
DownloadedBytes: downloaded,
|
||||
TotalFiles: totalFiles,
|
||||
DownloadedFiles: downloadedFiles,
|
||||
Percent: percent,
|
||||
SpeedMBPS: speedMBPS,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// setupTestServer creates a test server with handlers
|
||||
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
|
||||
|
||||
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
|
||||
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
|
||||
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
||||
tracker.OnStart(10000, 10)
|
||||
info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
||||
info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
updatesPerGoroutine := 100
|
||||
|
||||
// Concurrent progress updates
|
||||
// Concurrent progress updates via the Sink interface
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := range updatesPerGoroutine {
|
||||
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j)
|
||||
info.Emit(taskevent.Event{
|
||||
TaskID: "concurrent-progress",
|
||||
Phase: taskevent.PhaseProgress,
|
||||
DownloadedBytes: int64(id*updatesPerGoroutine + j),
|
||||
TotalBytes: 10000,
|
||||
})
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
info := tracker.GetInfo()
|
||||
if info.Status != TaskStatusRunning {
|
||||
t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
|
||||
status, _, downloaded, _, _, _, _, _ := info.snapshot()
|
||||
if status != TaskStatusRunning {
|
||||
t.Errorf("expected status Running after concurrent updates, got %s", status)
|
||||
}
|
||||
if downloaded <= 0 {
|
||||
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
|
||||
}
|
||||
// Note: Due to race conditions in the simple implementation,
|
||||
// we can't reliably check exact values without proper synchronization
|
||||
}
|
||||
|
||||
// TestTaskFactoryValidation tests TaskFactory parameter validation
|
||||
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
|
||||
{
|
||||
name: "Progress tracker with empty webhook",
|
||||
fn: func(t *testing.T) {
|
||||
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "")
|
||||
info := tracker.GetInfo()
|
||||
info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
|
||||
if info.Webhook != "" {
|
||||
t.Error("expected empty webhook")
|
||||
}
|
||||
|
||||
216
api/progress.go
216
api/progress.go
@@ -2,39 +2,48 @@ package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// TaskProgressInfo 存储任务的进度信息
|
||||
// TaskProgressInfo stores the progress of an API-submitted task. All fields are
|
||||
// guarded by mu. It implements taskevent.Sink so the task layer can update it
|
||||
// without knowing about the API.
|
||||
type TaskProgressInfo struct {
|
||||
TaskID string
|
||||
Type string
|
||||
Status TaskStatus
|
||||
Title string
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Storage string
|
||||
Path string
|
||||
Error string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Webhook string
|
||||
mu sync.Mutex
|
||||
TaskID string
|
||||
Type string
|
||||
Status TaskStatus
|
||||
Title string
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Storage string
|
||||
Path string
|
||||
Error string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
StartedAt time.Time
|
||||
Webhook string
|
||||
webhookNotified bool
|
||||
}
|
||||
|
||||
// progressStore 存储所有 API 任务的进度信息
|
||||
// progressStore holds all API tasks. Entries are removed a fixed duration after
|
||||
// they reach a terminal state to bound memory usage.
|
||||
type progressStore struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskProgressInfo
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskProgressInfo
|
||||
retention time.Duration
|
||||
}
|
||||
|
||||
var store = &progressStore{
|
||||
tasks: make(map[string]*TaskProgressInfo),
|
||||
tasks: make(map[string]*TaskProgressInfo),
|
||||
retention: 24 * time.Hour,
|
||||
}
|
||||
|
||||
// RegisterTask 注册一个新的 API 任务
|
||||
// RegisterTask registers a new API task and returns its progress info.
|
||||
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
|
||||
info := &TaskProgressInfo{
|
||||
TaskID: taskID,
|
||||
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
|
||||
return info
|
||||
}
|
||||
|
||||
// GetTask 获取任务进度信息
|
||||
// GetTask returns the progress info for a task.
|
||||
func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// GetAllTasks 获取所有任务
|
||||
// GetAllTasks returns all tracked tasks.
|
||||
func GetAllTasks() []*TaskProgressInfo {
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// DeleteTask 删除任务记录
|
||||
// DeleteTask removes a task record.
|
||||
func DeleteTask(taskID string) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
delete(store.tasks, taskID)
|
||||
}
|
||||
|
||||
// UpdateStatus 更新任务状态
|
||||
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
||||
t.Status = status
|
||||
t.UpdatedAt = time.Now()
|
||||
// CleanupExpired removes tasks that reached a terminal state more than the
|
||||
// store's retention duration ago. It is safe to call periodically.
|
||||
func CleanupExpired() {
|
||||
now := time.Now()
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
for id, info := range store.tasks {
|
||||
info.mu.Lock()
|
||||
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
|
||||
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
|
||||
info.mu.Unlock()
|
||||
if stale {
|
||||
delete(store.tasks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetError 设置错误信息
|
||||
// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
|
||||
// It should be started once during API server initialization.
|
||||
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
CleanupExpired()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// UpdateStatus sets the task status.
|
||||
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
||||
t.mu.Lock()
|
||||
t.Status = status
|
||||
t.UpdatedAt = time.Now()
|
||||
if status == TaskStatusRunning && t.StartedAt.IsZero() {
|
||||
t.StartedAt = t.UpdatedAt
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetError marks the task failed with an error message.
|
||||
func (t *TaskProgressInfo) SetError(err string) {
|
||||
t.mu.Lock()
|
||||
t.Error = err
|
||||
t.Status = TaskStatusFailed
|
||||
t.UpdatedAt = time.Now()
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProgressTracker 用于 API 任务的进度追踪
|
||||
type ProgressTracker struct {
|
||||
info *TaskProgressInfo
|
||||
// snapshot returns a point-in-time copy of the fields needed to render a
|
||||
// response, so callers never touch the mutex directly.
|
||||
func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
|
||||
}
|
||||
|
||||
// NewProgressTracker 创建新的进度追踪器
|
||||
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
||||
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
|
||||
return &ProgressTracker{info: info}
|
||||
}
|
||||
|
||||
// OnStart 任务开始
|
||||
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
|
||||
p.info.Status = TaskStatusRunning
|
||||
p.info.TotalBytes = totalBytes
|
||||
p.info.TotalFiles = totalFiles
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// OnProgress 进度更新
|
||||
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
|
||||
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
|
||||
p.info.DownloadedFiles = downloadedFiles
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// OnDone 任务完成
|
||||
func (p *ProgressTracker) OnDone(err error) {
|
||||
if err != nil {
|
||||
p.info.Status = TaskStatusFailed
|
||||
p.info.Error = err.Error()
|
||||
} else {
|
||||
p.info.Status = TaskStatusCompleted
|
||||
// Emit implements taskevent.Sink. It translates task lifecycle events into
|
||||
// status/progress updates and fires the webhook on terminal transitions.
|
||||
func (t *TaskProgressInfo) Emit(e taskevent.Event) {
|
||||
t.mu.Lock()
|
||||
switch e.Phase {
|
||||
case taskevent.PhaseStart:
|
||||
t.Status = TaskStatusRunning
|
||||
if t.StartedAt.IsZero() {
|
||||
t.StartedAt = time.Now()
|
||||
}
|
||||
if e.TotalBytes > 0 {
|
||||
t.TotalBytes = e.TotalBytes
|
||||
}
|
||||
case taskevent.PhaseProgress:
|
||||
t.Status = TaskStatusRunning
|
||||
if e.TotalBytes > 0 {
|
||||
t.TotalBytes = e.TotalBytes
|
||||
}
|
||||
t.DownloadedBytes = e.DownloadedBytes
|
||||
if e.TotalFiles > 0 {
|
||||
t.TotalFiles = e.TotalFiles
|
||||
}
|
||||
if e.DownloadedFiles > 0 {
|
||||
t.DownloadedFiles = e.DownloadedFiles
|
||||
}
|
||||
case taskevent.PhaseDone:
|
||||
if e.Err != nil {
|
||||
t.Status = TaskStatusFailed
|
||||
t.Error = e.Err.Error()
|
||||
} else {
|
||||
t.Status = TaskStatusCompleted
|
||||
}
|
||||
}
|
||||
t.UpdatedAt = time.Now()
|
||||
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
|
||||
if notify {
|
||||
t.webhookNotified = true
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if notify {
|
||||
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
|
||||
SendWebhook(nil, payload)
|
||||
}
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// GetInfo 获取任务信息
|
||||
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
|
||||
return p.info
|
||||
// ProgressTracker is retained for compatibility but is no longer the primary
|
||||
// progress path; taskevent drives updates now. These methods are safe no-ops
|
||||
// when called on a nil receiver.
|
||||
type ProgressTracker struct{}
|
||||
|
||||
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
||||
return &ProgressTracker{}
|
||||
}
|
||||
|
||||
// UpdateProgressBytes 更新下载字节数
|
||||
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
|
||||
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// UpdateProgressFiles 更新下载文件数
|
||||
func (p *ProgressTracker) UpdateProgressFiles(files int) {
|
||||
p.info.DownloadedFiles = files
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
|
||||
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
|
||||
func (p *ProgressTracker) OnDone(err error) {}
|
||||
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
|
||||
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
|
||||
func (p *ProgressTracker) UpdateProgressFiles(files int) {}
|
||||
|
||||
@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
|
||||
// 404 处理
|
||||
mux.HandleFunc("/", NotFoundHandler)
|
||||
|
||||
// 应用中间件
|
||||
// Apply middleware chain.
|
||||
var handler http.Handler = mux
|
||||
|
||||
// 添加认证中间件
|
||||
// Apply auth middleware when a token is configured.
|
||||
token := cfg.Token
|
||||
if token == "" {
|
||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
||||
}
|
||||
if token != "" {
|
||||
handler = AuthMiddleware()(handler)
|
||||
}
|
||||
|
||||
// 添加日志中间件
|
||||
// Add logging middleware.
|
||||
handler = loggingMiddleware(handler)
|
||||
|
||||
// 添加恢复中间件
|
||||
// Add recovery middleware.
|
||||
handler = recoveryMiddleware(handler)
|
||||
|
||||
return &Server{
|
||||
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Start 初始化并启动 API 服务器
|
||||
// Start initializes and starts the API server. It refuses to start without a
|
||||
// token, since an open download proxy is a security risk.
|
||||
func Start(ctx context.Context) error {
|
||||
cfg := config.C().API
|
||||
|
||||
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
||||
return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
|
||||
}
|
||||
|
||||
server := NewServer(ctx)
|
||||
return server.Start(ctx)
|
||||
if err := server.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
StartCleanupLoop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
|
||||
type TaskProgress struct {
|
||||
TotalBytes int64 `json:"total_bytes,omitempty"`
|
||||
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
|
||||
TotalFiles int `json:"total_files,omitempty"`
|
||||
DownloadedFiles int `json:"downloaded_files,omitempty"`
|
||||
Percent float64 `json:"percent,omitempty"`
|
||||
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
||||
|
||||
webhookURL := info.Webhook
|
||||
|
||||
// 异步发送 webhook
|
||||
// Async send with retries.
|
||||
go func() {
|
||||
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
|
||||
var logger *log.Logger
|
||||
if ctx != nil {
|
||||
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
|
||||
} else {
|
||||
logger = log.Default().With("task_id", payload.TaskID)
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
||||
}()
|
||||
}
|
||||
|
||||
// CreateWebhookPayload 创建 Webhook 负载
|
||||
// CreateWebhookPayload creates a Webhook payload.
|
||||
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
|
||||
payload := &WebhookPayload{
|
||||
TaskID: taskID,
|
||||
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
|
||||
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
|
||||
info, ok := GetTask(taskID)
|
||||
if !ok {
|
||||
return fmt.Errorf("task not found: %s", taskID)
|
||||
}
|
||||
|
||||
err := fn()
|
||||
|
||||
// 确定任务状态
|
||||
status := TaskStatusCompleted
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
status = TaskStatusCancelled
|
||||
} else {
|
||||
status = TaskStatusFailed
|
||||
}
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
if err != nil {
|
||||
info.SetError(err.Error())
|
||||
} else {
|
||||
info.UpdateStatus(TaskStatusCompleted)
|
||||
}
|
||||
|
||||
// 发送 webhook
|
||||
if info.Webhook != "" {
|
||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
||||
SendWebhook(ctx, payload)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
)
|
||||
|
||||
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
|
||||
type ExecutableWrapper struct {
|
||||
inner core.Executable
|
||||
}
|
||||
|
||||
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
|
||||
return &ExecutableWrapper{inner: inner}
|
||||
}
|
||||
|
||||
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
|
||||
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
|
||||
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
|
||||
|
||||
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
|
||||
taskID := w.inner.TaskID()
|
||||
|
||||
if info, ok := GetTask(taskID); ok {
|
||||
info.UpdateStatus(TaskStatusRunning)
|
||||
}
|
||||
|
||||
err := w.inner.Execute(ctx)
|
||||
|
||||
info, ok := GetTask(taskID)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
var status TaskStatus
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
status = TaskStatusCancelled
|
||||
info.UpdateStatus(TaskStatusCancelled)
|
||||
} else {
|
||||
status = TaskStatusFailed
|
||||
info.SetError(err.Error())
|
||||
}
|
||||
} else {
|
||||
status = TaskStatusCompleted
|
||||
info.UpdateStatus(TaskStatusCompleted)
|
||||
}
|
||||
|
||||
if info.Webhook != "" {
|
||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
||||
SendWebhook(ctx, payload)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
var queueInstance *queue.TaskQueue[Executable]
|
||||
@@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
||||
break // queue closed and empty
|
||||
}
|
||||
exe := qtask.Data
|
||||
taskCtx := qtask.Context()
|
||||
logger.Infof("Processing task: %s", exe.TaskID())
|
||||
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
|
||||
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart})
|
||||
if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil {
|
||||
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
|
||||
}
|
||||
if err := exe.Execute(qtask.Context()); err != nil {
|
||||
err = exe.Execute(taskCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Task %s was canceled", exe.TaskID())
|
||||
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
|
||||
@@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
||||
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
|
||||
}
|
||||
}
|
||||
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err})
|
||||
qe.Done(qtask.ID)
|
||||
<-semaphore
|
||||
}
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
@@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error {
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, status)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: parseInt64(status.TotalLength),
|
||||
DownloadedBytes: parseInt64(status.CompletedLength),
|
||||
})
|
||||
|
||||
// Check if download is complete
|
||||
if status.IsDownloadComplete() {
|
||||
@@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() {
|
||||
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseInt64 parses an aria2 status string (decimal bytes) into int64,
|
||||
// returning 0 on failure so it can be used directly in progress events.
|
||||
func parseInt64(s string) int64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
n, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
return elem.Storage.Save(uploadCtx, pr, elem.Path)
|
||||
})
|
||||
wr := ioutil.NewProgressWriter(pw, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
downloaded := t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
errg.Go(func() error {
|
||||
defer pw.Close()
|
||||
@@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
}
|
||||
}()
|
||||
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
downloaded := t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
|
||||
}
|
||||
}()
|
||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||
t.downloadedBytes.Add(int64(n))
|
||||
downloaded := t.downloadedBytes.Add(int64(n))
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalBytes,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
|
||||
}
|
||||
}()
|
||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||
t.downloadedBytes.Add(int64(n))
|
||||
downloaded := t.downloadedBytes.Add(int64(n))
|
||||
if t.progress != nil {
|
||||
t.progress.OnProgress(ctx, t)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalBytes,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/duke-git/lancet/v2/retry"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error {
|
||||
logger.Errorf("Error processing picture %s: %v", pic, err)
|
||||
return fmt.Errorf("failed to process picture %s: %w", pic, err)
|
||||
}
|
||||
t.downloaded.Add(1)
|
||||
downloaded := t.downloaded.Add(1)
|
||||
t.progress.OnProgress(gctx, t)
|
||||
taskevent.Emit(gctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalFiles: t.totalpics,
|
||||
DownloadedFiles: int(downloaded),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
type ProgressWriterAt struct {
|
||||
@@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
downloaded := w.downloaded.Add(int64(at))
|
||||
if w.progress != nil {
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||
}
|
||||
taskevent.Emit(w.ctx, taskevent.Event{
|
||||
TaskID: w.info.TaskID(),
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: w.total,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
return at, nil
|
||||
}
|
||||
|
||||
@@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
downloaded := w.downloaded.Add(int64(at))
|
||||
if w.progress != nil {
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||
}
|
||||
taskevent.Emit(w.ctx, taskevent.Event{
|
||||
TaskID: w.info.TaskID(),
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: w.total,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
return at, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
|
||||
t.uploaded.Add(size)
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: t.uploaded.Load(),
|
||||
})
|
||||
|
||||
logger.Info("File uploaded successfully")
|
||||
return nil
|
||||
|
||||
88
pkg/taskevent/taskevent.go
Normal file
88
pkg/taskevent/taskevent.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Package taskevent provides a decoupled, context-scoped event bus for task
|
||||
// lifecycle progress. Producers (task implementations) emit events via Emit;
|
||||
// consumers (e.g. the API progress store, the Telegram message editor) register
|
||||
// as Sinks and are injected through context. This keeps the task layer free of
|
||||
// any concrete progress-display dependency, so new task types gain progress
|
||||
// reporting for free and new observers can be added without touching tasks.
|
||||
package taskevent
|
||||
|
||||
import "context"
|
||||
|
||||
// Phase marks a stage in a task's lifecycle.
|
||||
type Phase int
|
||||
|
||||
const (
|
||||
PhaseStart Phase = iota
|
||||
PhaseProgress
|
||||
PhaseDone
|
||||
)
|
||||
|
||||
func (p Phase) String() string {
|
||||
switch p {
|
||||
case PhaseStart:
|
||||
return "start"
|
||||
case PhaseProgress:
|
||||
return "progress"
|
||||
case PhaseDone:
|
||||
return "done"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Event describes a single progress observation for a task. Byte fields are
|
||||
// populated by byte-stream tasks; file-count fields by count-based tasks. A
|
||||
// task may fill whichever subset it has; observers ignore zero values.
|
||||
type Event struct {
|
||||
TaskID string
|
||||
Phase Phase
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Err error
|
||||
}
|
||||
|
||||
// Sink receives task events. Implementations must be safe for concurrent use.
|
||||
type Sink interface {
|
||||
Emit(Event)
|
||||
}
|
||||
|
||||
// SinkFunc is a function adapter for Sink.
|
||||
type SinkFunc func(Event)
|
||||
|
||||
func (f SinkFunc) Emit(e Event) { f(e) }
|
||||
|
||||
type sinkKey struct{}
|
||||
|
||||
// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed
|
||||
// and all will receive every emitted event. Sinks already present in ctx are
|
||||
// preserved.
|
||||
func WithSink(ctx context.Context, sinks ...Sink) context.Context {
|
||||
if len(sinks) == 0 {
|
||||
return ctx
|
||||
}
|
||||
var existing []Sink
|
||||
if v, ok := ctx.Value(sinkKey{}).([]Sink); ok {
|
||||
existing = v
|
||||
}
|
||||
merged := make([]Sink, 0, len(existing)+len(sinks))
|
||||
merged = append(merged, existing...)
|
||||
merged = append(merged, sinks...)
|
||||
return context.WithValue(ctx, sinkKey{}, merged)
|
||||
}
|
||||
|
||||
// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no
|
||||
// sink is attached, so producers can call it unconditionally.
|
||||
func Emit(ctx context.Context, e Event) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
sinks, ok := ctx.Value(sinkKey{}).([]Sink)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, s := range sinks {
|
||||
s.Emit(e)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user