mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-29 20:19:57 +08:00
fix(api): sync task lifecycle state and restore GET /api/v1/tasks (#216)
* fix(api): update task route to handle GET and POST methods Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com> * fix(api): implement ExecutableWrapper to manage task execution and status updates Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com> * fix(api): refactor task registration and enqueueing into a separate method Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com> --------- Signed-off-by: Ilham Syahid S <ilhamsyahids@gmail.com>
This commit is contained in:
@@ -66,6 +66,19 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
|
||||
}
|
||||
}
|
||||
|
||||
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
|
||||
taskID := task.TaskID()
|
||||
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||
|
||||
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
|
||||
if err != nil {
|
||||
DeleteTask(taskID)
|
||||
return fmt.Errorf("failed to add task: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDirectLinksTask 创建直链下载任务
|
||||
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
||||
var params DirectLinksParams
|
||||
@@ -79,8 +92,9 @@ func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time,
|
||||
|
||||
task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err := f.registerAndEnqueueTask(task, tasktype.TaskTypeDirectlinks, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -104,8 +118,9 @@ func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *C
|
||||
|
||||
task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err := f.registerAndEnqueueTask(task, tasktype.TaskTypeYtdlp, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -146,8 +161,9 @@ func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *C
|
||||
|
||||
task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeAria2, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -190,8 +206,9 @@ func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *
|
||||
|
||||
task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeParseditem, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -223,15 +240,15 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
|
||||
return nil, fmt.Errorf("no files found in provided links")
|
||||
}
|
||||
|
||||
var task core.Executable
|
||||
|
||||
if len(files) == 1 {
|
||||
// 单个文件任务
|
||||
tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tfile task: %w", err)
|
||||
}
|
||||
if err := core.AddTask(f.ctx, tfileTask); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
}
|
||||
task = tfileTask
|
||||
} else {
|
||||
// 批量文件任务
|
||||
elems := make([]batchtfile.TaskElement, 0, len(files))
|
||||
@@ -243,10 +260,12 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
|
||||
elems = append(elems, *elem)
|
||||
}
|
||||
|
||||
task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
}
|
||||
task = batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
|
||||
}
|
||||
|
||||
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTgfiles, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -281,8 +300,9 @@ func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req
|
||||
client := telegraph.NewClient()
|
||||
task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTphpics, req.Storage, req.Path, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
@@ -342,8 +362,9 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req
|
||||
|
||||
task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
|
||||
|
||||
if err := core.AddTask(f.ctx, task); err != nil {
|
||||
return nil, fmt.Errorf("failed to add task: %w", err)
|
||||
err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTransfer, params.TargetStorage, params.TargetPath, req.Webhook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateTaskResponse{
|
||||
|
||||
@@ -30,16 +30,21 @@ func NewServer(ctx context.Context) *Server {
|
||||
mux.HandleFunc("/health", handlers.HealthCheckHandler)
|
||||
|
||||
// API v1 路由
|
||||
mux.HandleFunc("/api/v1/tasks", handlers.CreateTaskHandler)
|
||||
mux.HandleFunc("/api/v1/tasks", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
handlers.ListTasksHandler(w, r)
|
||||
case http.MethodPost:
|
||||
handlers.CreateTaskHandler(w, r)
|
||||
default:
|
||||
MethodNotAllowedHandler(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// 根据方法和路径分发
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if r.URL.Path == "/api/v1/tasks" {
|
||||
handlers.ListTasksHandler(w, r)
|
||||
} else {
|
||||
handlers.GetTaskHandler(w, r)
|
||||
}
|
||||
handlers.GetTaskHandler(w, r)
|
||||
case http.MethodDelete:
|
||||
handlers.CancelTaskHandler(w, r)
|
||||
default:
|
||||
|
||||
58
api/wrapper.go
Normal file
58
api/wrapper.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
)
|
||||
|
||||
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
|
||||
type ExecutableWrapper struct {
|
||||
inner core.Executable
|
||||
}
|
||||
|
||||
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
|
||||
return &ExecutableWrapper{inner: inner}
|
||||
}
|
||||
|
||||
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
|
||||
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
|
||||
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
|
||||
|
||||
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
|
||||
taskID := w.inner.TaskID()
|
||||
|
||||
if info, ok := GetTask(taskID); ok {
|
||||
info.UpdateStatus(TaskStatusRunning)
|
||||
}
|
||||
|
||||
err := w.inner.Execute(ctx)
|
||||
|
||||
info, ok := GetTask(taskID)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
var status TaskStatus
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
status = TaskStatusCancelled
|
||||
info.UpdateStatus(TaskStatusCancelled)
|
||||
} else {
|
||||
status = TaskStatusFailed
|
||||
info.SetError(err.Error())
|
||||
}
|
||||
} else {
|
||||
status = TaskStatusCompleted
|
||||
info.UpdateStatus(TaskStatusCompleted)
|
||||
}
|
||||
|
||||
if info.Webhook != "" {
|
||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
||||
SendWebhook(ctx, payload)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user