fix(api): sync task lifecycle state and restore GET /api/v1/tasks (#216)

* fix(api): update task route to handle GET and POST methods

Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com>

* fix(api): implement ExecutableWrapper to manage task execution and status updates

Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com>

* fix(api): refactor task registration and enqueueing into a separate method

Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com>

---------

Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com>
This commit is contained in:
Ilham Syahid S
2026-05-24 22:42:16 +07:00
committed by GitHub
parent bfab4c85c8
commit 8059e27978
3 changed files with 109 additions and 25 deletions

View File

@@ -66,6 +66,19 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
}
}
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
taskID := task.TaskID()
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
if err != nil {
DeleteTask(taskID)
return fmt.Errorf("failed to add task: %w", err)
}
return nil
}
// createDirectLinksTask 创建直链下载任务
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params DirectLinksParams
@@ -79,8 +92,9 @@ func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time,
task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err := f.registerAndEnqueueTask(task, tasktype.TaskTypeDirectlinks, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -104,8 +118,9 @@ func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *C
task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err := f.registerAndEnqueueTask(task, tasktype.TaskTypeYtdlp, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -146,8 +161,9 @@ func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *C
task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeAria2, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -190,8 +206,9 @@ func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *
task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeParseditem, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -223,15 +240,15 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
return nil, fmt.Errorf("no files found in provided links")
}
var task core.Executable
if len(files) == 1 {
// 单个文件任务
tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
if err != nil {
return nil, fmt.Errorf("failed to create tfile task: %w", err)
}
if err := core.AddTask(f.ctx, tfileTask); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
task = tfileTask
} else {
// 批量文件任务
elems := make([]batchtfile.TaskElement, 0, len(files))
@@ -243,10 +260,12 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
elems = append(elems, *elem)
}
task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
task = batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
}
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTgfiles, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -281,8 +300,9 @@ func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req
client := telegraph.NewClient()
task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTphpics, req.Storage, req.Path, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{
@@ -342,8 +362,9 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req
task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTransfer, params.TargetStorage, params.TargetPath, req.Webhook)
if err != nil {
return nil, err
}
return &CreateTaskResponse{

View File

@@ -30,16 +30,21 @@ func NewServer(ctx context.Context) *Server {
mux.HandleFunc("/health", handlers.HealthCheckHandler)
// API v1 路由
mux.HandleFunc("/api/v1/tasks", handlers.CreateTaskHandler)
mux.HandleFunc("/api/v1/tasks", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
handlers.ListTasksHandler(w, r)
case http.MethodPost:
handlers.CreateTaskHandler(w, r)
default:
MethodNotAllowedHandler(w, r)
}
})
mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) {
// 根据方法和路径分发
switch r.Method {
case http.MethodGet:
if r.URL.Path == "/api/v1/tasks" {
handlers.ListTasksHandler(w, r)
} else {
handlers.GetTaskHandler(w, r)
}
handlers.GetTaskHandler(w, r)
case http.MethodDelete:
handlers.CancelTaskHandler(w, r)
default:

58
api/wrapper.go Normal file
View File

@@ -0,0 +1,58 @@
package api
import (
"context"
"errors"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
type ExecutableWrapper struct {
inner core.Executable
}
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
return &ExecutableWrapper{inner: inner}
}
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
taskID := w.inner.TaskID()
if info, ok := GetTask(taskID); ok {
info.UpdateStatus(TaskStatusRunning)
}
err := w.inner.Execute(ctx)
info, ok := GetTask(taskID)
if !ok {
return err
}
var status TaskStatus
if err != nil {
if errors.Is(err, context.Canceled) {
status = TaskStatusCancelled
info.UpdateStatus(TaskStatusCancelled)
} else {
status = TaskStatusFailed
info.SetError(err.Error())
}
} else {
status = TaskStatusCompleted
info.UpdateStatus(TaskStatusCompleted)
}
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}