From 98ba7c50e783d181f456055825c89017d6534156 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Thu, 27 Feb 2025 21:32:14 +0800 Subject: [PATCH 1/4] refactor: remove unused StoragePath initialization in AddToQueue function --- bot/handle_add_task.go | 1 - 1 file changed, 1 deletion(-) diff --git a/bot/handle_add_task.go b/bot/handle_add_task.go index 3b30c51..5804217 100644 --- a/bot/handle_add_task.go +++ b/bot/handle_add_task.go @@ -153,7 +153,6 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { Status: types.Pending, File: file, StorageName: storageName, - StoragePath: path.Join(), FileChatID: record.ChatID, ReplyMessageID: record.ReplyMessageID, FileMessageID: record.MessageID, From be6444cf96d2d07c74b9622fb311857302861cf8 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:02:16 +0800 Subject: [PATCH 2/4] feat: implement task cancellation feature and update task handling --- bot/handle_add_task.go | 2 +- bot/handle_cancel_task.go | 27 ++++++++++++++++++ bot/handlers.go | 1 + bot/utils.go | 2 +- core/core.go | 4 +-- core/download.go | 14 +++++++--- core/utils.go | 13 +++++++-- queue/queue.go | 59 ++++++++++++++++++++++++++++++--------- types/types.go | 5 ++++ 9 files changed, 103 insertions(+), 24 deletions(-) create mode 100644 bot/handle_cancel_task.go diff --git a/bot/handle_add_task.go b/bot/handle_add_task.go index 5804217..167caff 100644 --- a/bot/handle_add_task.go +++ b/bot/handle_add_task.go @@ -163,7 +163,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { task.StoragePath = path.Join(dir.Path, file.FileName) } - queue.AddTask(task) + queue.AddTask(&task) entityBuilder := entity.Builder{} var entities []tg.MessageEntityClass diff --git a/bot/handle_cancel_task.go b/bot/handle_cancel_task.go new file mode 100644 index 0000000..512df00 --- /dev/null +++ b/bot/handle_cancel_task.go @@ -0,0 +1,27 @@ +package bot + +import ( + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/queue" +) + +func cancelTask(ctx *ext.Context, update *ext.Update) error { + key := strings.Split(string(update.CallbackQuery.Data), " ")[1] + ok := queue.CancelTask(key) + if ok { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Message: "任务已取消", + }) + return dispatcher.EndGroups + } + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Message: "任务取消失败", + }) + return dispatcher.EndGroups +} diff --git a/bot/handlers.go b/bot/handlers.go index 7f5f2c0..3631fa5 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -22,5 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage)) + dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask)) dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage)) } diff --git a/bot/utils.go b/bot/utils.go index 1c42c89..e008e9c 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -264,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t }) return dispatcher.EndGroups } - queue.AddTask(*task) + queue.AddTask(task) ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()), ID: task.ReplyMessageID, diff --git a/core/core.go b/core/core.go index c91f043..b2dcc62 100644 --- a/core/core.go +++ b/core/core.go @@ -22,13 +22,13 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { switch task.Status { case types.Pending: logger.L.Infof("Processing task: %s", task.String()) - if err := processPendingTask(&task); err != nil { - logger.L.Errorf("Failed to do task: %s", err) + if err := processPendingTask(task); err != nil { task.Error = err if errors.Is(err, context.Canceled) { logger.L.Debugf("Task canceled: %s", task.String()) task.Status = types.Canceled } else { + logger.L.Errorf("Failed to do task: %s", err) task.Status = types.Failed } } else { diff --git a/core/download.go b/core/download.go index 960bfd1..e2db1bb 100644 --- a/core/download.go +++ b/core/download.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "path/filepath" "time" @@ -48,11 +49,16 @@ func processPendingTask(task *types.Task) error { return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) } + cancelCtx, cancel := context.WithCancel(ctx) + task.Cancel = cancel + task.Ctx = cancelCtx + text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), }) progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) @@ -63,7 +69,7 @@ func processPendingTask(task *types.Task) error { defer dest.Close() task.StartTime = time.Now() downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) - _, err = downloadBuider.Parallel(ctx, dest) + _, err = downloadBuider.Parallel(cancelCtx, dest) if err != nil { return fmt.Errorf("下载文件失败: %w", err) } diff --git a/core/utils.go b/core/utils.go index d785d20..3a069b7 100644 --- a/core/utils.go +++ b/core/utils.go @@ -139,13 +139,20 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) } text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), }) } } +func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup { + return &tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}}, + } +} + func fixTaskFileExt(task *types.Task, localFilePath string) { if path.Ext(task.FileName()) == "" { mimeType, err := mimetype.DetectFile(localFilePath) diff --git a/queue/queue.go b/queue/queue.go index 4ac8733..0434c8c 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -8,30 +8,58 @@ import ( ) type TaskQueue struct { - list *list.List - cond *sync.Cond - mutex *sync.Mutex + list *list.List + cond *sync.Cond + mutex *sync.Mutex + activeMap map[string]*types.Task } -func (q *TaskQueue) AddTask(task types.Task) { +func (q *TaskQueue) AddTask(task *types.Task) { q.mutex.Lock() defer q.mutex.Unlock() - q.list.PushBack(task) - q.cond.Signal() + if task.Status == types.Pending { + q.list.PushBack(task) + q.cond.Signal() + } else { + delete(q.activeMap, task.Key()) + } } -func (q *TaskQueue) GetTask() types.Task { +func (q *TaskQueue) GetTask() *types.Task { q.mutex.Lock() defer q.mutex.Unlock() for q.list.Len() == 0 { q.cond.Wait() } e := q.list.Front() - task := e.Value.(types.Task) + task := e.Value.(*types.Task) q.list.Remove(e) + q.activeMap[task.Key()] = task return task } +func (q *TaskQueue) CancelTask(key string) bool { + q.mutex.Lock() + defer q.mutex.Unlock() + if task, ok := q.activeMap[key]; ok { + if task.Cancel != nil { + task.Cancel() + return true + } + } + for e := q.list.Front(); e != nil; e = e.Next() { + task := e.Value.(*types.Task) + if task.Key() == key { + if task.Cancel != nil { + task.Cancel() + } + q.list.Remove(e) + return true + } + } + return false +} + func (q *TaskQueue) Len() int { q.mutex.Lock() defer q.mutex.Unlock() @@ -47,20 +75,25 @@ func init() { func NewQueue() *TaskQueue { m := &sync.Mutex{} return &TaskQueue{ - list: list.New(), - cond: sync.NewCond(m), - mutex: m, + list: list.New(), + cond: sync.NewCond(m), + mutex: m, + activeMap: make(map[string]*types.Task), } } -func AddTask(task types.Task) { +func AddTask(task *types.Task) { Queue.AddTask(task) } -func GetTask() types.Task { +func GetTask() *types.Task { return Queue.GetTask() } func Len() int { return Queue.Len() } + +func CancelTask(key string) bool { + return Queue.CancelTask(key) +} diff --git a/types/types.go b/types/types.go index 9d448f2..c2e0ca4 100644 --- a/types/types.go +++ b/types/types.go @@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{ type Task struct { Ctx context.Context + Cancel context.CancelFunc Error error Status TaskStatus File *File @@ -52,6 +53,10 @@ type Task struct { UserID int64 } +func (t Task) Key() string { + return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) +} + func (t Task) String() string { return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) } From 7015081a84aaaf81eedef8005b6732cca711e50c Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:07:41 +0800 Subject: [PATCH 3/4] feat: add context cancellation handling in saveFileWithRetry function --- core/utils.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/utils.go b/core/utils.go index 3a069b7..ced499e 100644 --- a/core/utils.go +++ b/core/utils.go @@ -21,11 +21,19 @@ import ( func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { + if err := task.Ctx.Err(); err != nil { + return fmt.Errorf("context canceled while saving file: %w", err) + } if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } logger.L.Errorf("Failed to save file: %s, retrying...", err) + select { + case <-task.Ctx.Done(): + return fmt.Errorf("context canceled during retry delay: %w", task.Ctx.Err()) + case <-time.After(time.Duration(i*500) * time.Millisecond): + } continue } return nil From 152f4731315ad3bec525e5bc49ebc5248c1d2412 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:25:10 +0800 Subject: [PATCH 4/4] fix: delete done task --- core/core.go | 39 +++++++++++++++++++++++++++++---------- core/download.go | 3 +-- core/utils.go | 13 +++++++------ queue/queue.go | 21 ++++++++++++++++----- 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/core/core.go b/core/core.go index b2dcc62..4852d37 100644 --- a/core/core.go +++ b/core/core.go @@ -25,7 +25,6 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { if err := processPendingTask(task); err != nil { task.Error = err if errors.Is(err, context.Canceled) { - logger.L.Debugf("Task canceled: %s", task.String()) task.Status = types.Canceled } else { logger.L.Errorf("Failed to do task: %s", err) @@ -37,23 +36,43 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { queue.AddTask(task) case types.Succeeded: logger.L.Infof("Task succeeded: %s", task.String()) - task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), - ID: task.ReplyMessageID, - }) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), + ID: task.ReplyMessageID, + }) + } case types.Failed: logger.L.Errorf("Task failed: %s", task.String()) - task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: "文件保存失败\n" + task.Error.Error(), - ID: task.ReplyMessageID, - }) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: "文件保存失败\n" + task.Error.Error(), + ID: task.ReplyMessageID, + }) + } case types.Canceled: logger.L.Infof("Task canceled: %s", task.String()) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: "任务已取消", + ID: task.ReplyMessageID, + }) + } default: logger.L.Errorf("Unknown task status: %s", task.Status) } <-semaphore - logger.L.Debugf("Task done: %s", task.String()) + logger.L.Debugf("Task done: %s; status: %s", task.String(), task.Status) + queue.DoneTask(task) } } diff --git a/core/download.go b/core/download.go index e2db1bb..cce4308 100644 --- a/core/download.go +++ b/core/download.go @@ -51,7 +51,6 @@ func processPendingTask(task *types.Task) error { cancelCtx, cancel := context.WithCancel(ctx) task.Cancel = cancel - task.Ctx = cancelCtx text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ @@ -84,5 +83,5 @@ func processPendingTask(task *types.Task) error { ID: task.ReplyMessageID, }) - return saveFileWithRetry(task, taskStorage, cacheDestPath) + return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath) } diff --git a/core/utils.go b/core/utils.go index ced499e..fed31cf 100644 --- a/core/utils.go +++ b/core/utils.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "os" "path" @@ -19,19 +20,19 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error { +func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { - if err := task.Ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) } - if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil { + if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } logger.L.Errorf("Failed to save file: %s, retrying...", err) select { - case <-task.Ctx.Done(): - return fmt.Errorf("context canceled during retry delay: %w", task.Ctx.Err()) + case <-ctx.Done(): + return fmt.Errorf("context canceled during retry delay: %w", ctx.Err()) case <-time.After(time.Duration(i*500) * time.Millisecond): } continue @@ -64,7 +65,7 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin logger.L.Infof("Downloaded file: %s", cachePath) - return saveFileWithRetry(task, taskStorage, cachePath) + return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath) } // func getProgressBar(progress float64, updateCount int) string { diff --git a/queue/queue.go b/queue/queue.go index 0434c8c..78401fc 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -17,10 +17,9 @@ type TaskQueue struct { func (q *TaskQueue) AddTask(task *types.Task) { q.mutex.Lock() defer q.mutex.Unlock() - if task.Status == types.Pending { - q.list.PushBack(task) - q.cond.Signal() - } else { + q.list.PushBack(task) + q.cond.Signal() + if task.Status != types.Pending { delete(q.activeMap, task.Key()) } } @@ -34,10 +33,18 @@ func (q *TaskQueue) GetTask() *types.Task { e := q.list.Front() task := e.Value.(*types.Task) q.list.Remove(e) - q.activeMap[task.Key()] = task + if task.Status == types.Pending { + q.activeMap[task.Key()] = task + } return task } +func (q *TaskQueue) DoneTask(task *types.Task) { + q.mutex.Lock() + defer q.mutex.Unlock() + delete(q.activeMap, task.Key()) +} + func (q *TaskQueue) CancelTask(key string) bool { q.mutex.Lock() defer q.mutex.Unlock() @@ -97,3 +104,7 @@ func Len() int { func CancelTask(key string) bool { return Queue.CancelTask(key) } + +func DoneTask(task *types.Task) { + Queue.DoneTask(task) +}