diff --git a/api/factory.go b/api/factory.go index 7020173..4c00a6b 100644 --- a/api/factory.go +++ b/api/factory.go @@ -66,6 +66,19 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e } } +func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error { + taskID := task.TaskID() + RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook) + + err := core.AddTask(f.ctx, NewExecutableWrapper(task)) + if err != nil { + DeleteTask(taskID) + return fmt.Errorf("failed to add task: %w", err) + } + + return nil +} + // createDirectLinksTask 创建直链下载任务 func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { var params DirectLinksParams @@ -79,8 +92,9 @@ func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err := f.registerAndEnqueueTask(task, tasktype.TaskTypeDirectlinks, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -104,8 +118,9 @@ func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *C task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err := f.registerAndEnqueueTask(task, tasktype.TaskTypeYtdlp, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -146,8 +161,9 @@ func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *C task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err = f.registerAndEnqueueTask(task, tasktype.TaskTypeAria2, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -190,8 +206,9 @@ func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req * task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err = f.registerAndEnqueueTask(task, tasktype.TaskTypeParseditem, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -223,15 +240,15 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req return nil, fmt.Errorf("no files found in provided links") } + var task core.Executable + if len(files) == 1 { // 单个文件任务 tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil) if err != nil { return nil, fmt.Errorf("failed to create tfile task: %w", err) } - if err := core.AddTask(f.ctx, tfileTask); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) - } + task = tfileTask } else { // 批量文件任务 elems := make([]batchtfile.TaskElement, 0, len(files)) @@ -243,10 +260,12 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req elems = append(elems, *elem) } - task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) - } + task = batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true) + } + + err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTgfiles, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -281,8 +300,9 @@ func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req client := telegraph.NewClient() task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTphpics, req.Storage, req.Path, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ @@ -342,8 +362,9 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true) - if err := core.AddTask(f.ctx, task); err != nil { - return nil, fmt.Errorf("failed to add task: %w", err) + err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTransfer, params.TargetStorage, params.TargetPath, req.Webhook) + if err != nil { + return nil, err } return &CreateTaskResponse{ diff --git a/api/server.go b/api/server.go index 5e65e29..be8b660 100644 --- a/api/server.go +++ b/api/server.go @@ -30,16 +30,21 @@ func NewServer(ctx context.Context) *Server { mux.HandleFunc("/health", handlers.HealthCheckHandler) // API v1 路由 - mux.HandleFunc("/api/v1/tasks", handlers.CreateTaskHandler) + mux.HandleFunc("/api/v1/tasks", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + handlers.ListTasksHandler(w, r) + case http.MethodPost: + handlers.CreateTaskHandler(w, r) + default: + MethodNotAllowedHandler(w, r) + } + }) mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) { // 根据方法和路径分发 switch r.Method { case http.MethodGet: - if r.URL.Path == "/api/v1/tasks" { - handlers.ListTasksHandler(w, r) - } else { - handlers.GetTaskHandler(w, r) - } + handlers.GetTaskHandler(w, r) case http.MethodDelete: handlers.CancelTaskHandler(w, r) default: diff --git a/api/wrapper.go b/api/wrapper.go new file mode 100644 index 0000000..a3bb18f --- /dev/null +++ b/api/wrapper.go @@ -0,0 +1,58 @@ +package api + +import ( + "context" + "errors" + + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" +) + +// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks. +type ExecutableWrapper struct { + inner core.Executable +} + +func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper { + return &ExecutableWrapper{inner: inner} +} + +func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() } +func (w *ExecutableWrapper) Title() string { return w.inner.Title() } +func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() } + +func (w *ExecutableWrapper) Execute(ctx context.Context) error { + taskID := w.inner.TaskID() + + if info, ok := GetTask(taskID); ok { + info.UpdateStatus(TaskStatusRunning) + } + + err := w.inner.Execute(ctx) + + info, ok := GetTask(taskID) + if !ok { + return err + } + + var status TaskStatus + if err != nil { + if errors.Is(err, context.Canceled) { + status = TaskStatusCancelled + info.UpdateStatus(TaskStatusCancelled) + } else { + status = TaskStatusFailed + info.SetError(err.Error()) + } + } else { + status = TaskStatusCompleted + info.UpdateStatus(TaskStatusCompleted) + } + + if info.Webhook != "" { + payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err) + SendWebhook(ctx, payload) + } + + return err +}