diff --git a/bot/handle_add_task.go b/bot/handle_add_task.go index 5804217..167caff 100644 --- a/bot/handle_add_task.go +++ b/bot/handle_add_task.go @@ -163,7 +163,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { task.StoragePath = path.Join(dir.Path, file.FileName) } - queue.AddTask(task) + queue.AddTask(&task) entityBuilder := entity.Builder{} var entities []tg.MessageEntityClass diff --git a/bot/handle_cancel_task.go b/bot/handle_cancel_task.go new file mode 100644 index 0000000..512df00 --- /dev/null +++ b/bot/handle_cancel_task.go @@ -0,0 +1,27 @@ +package bot + +import ( + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/queue" +) + +func cancelTask(ctx *ext.Context, update *ext.Update) error { + key := strings.Split(string(update.CallbackQuery.Data), " ")[1] + ok := queue.CancelTask(key) + if ok { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Message: "任务已取消", + }) + return dispatcher.EndGroups + } + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Message: "任务取消失败", + }) + return dispatcher.EndGroups +} diff --git a/bot/handlers.go b/bot/handlers.go index 7f5f2c0..3631fa5 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -22,5 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage)) + dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask)) dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage)) } diff --git a/bot/utils.go b/bot/utils.go index 1c42c89..e008e9c 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -264,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t }) return dispatcher.EndGroups } - queue.AddTask(*task) + queue.AddTask(task) ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()), ID: task.ReplyMessageID, diff --git a/core/core.go b/core/core.go index c91f043..b2dcc62 100644 --- a/core/core.go +++ b/core/core.go @@ -22,13 +22,13 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { switch task.Status { case types.Pending: logger.L.Infof("Processing task: %s", task.String()) - if err := processPendingTask(&task); err != nil { - logger.L.Errorf("Failed to do task: %s", err) + if err := processPendingTask(task); err != nil { task.Error = err if errors.Is(err, context.Canceled) { logger.L.Debugf("Task canceled: %s", task.String()) task.Status = types.Canceled } else { + logger.L.Errorf("Failed to do task: %s", err) task.Status = types.Failed } } else { diff --git a/core/download.go b/core/download.go index 960bfd1..e2db1bb 100644 --- a/core/download.go +++ b/core/download.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "path/filepath" "time" @@ -48,11 +49,16 @@ func processPendingTask(task *types.Task) error { return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) } + cancelCtx, cancel := context.WithCancel(ctx) + task.Cancel = cancel + task.Ctx = cancelCtx + text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), }) progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) @@ -63,7 +69,7 @@ func processPendingTask(task *types.Task) error { defer dest.Close() task.StartTime = time.Now() downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) - _, err = downloadBuider.Parallel(ctx, dest) + _, err = downloadBuider.Parallel(cancelCtx, dest) if err != nil { return fmt.Errorf("下载文件失败: %w", err) } diff --git a/core/utils.go b/core/utils.go index d785d20..3a069b7 100644 --- a/core/utils.go +++ b/core/utils.go @@ -139,13 +139,20 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) } text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: text, - Entities: entities, - ID: task.ReplyMessageID, + Message: text, + Entities: entities, + ID: task.ReplyMessageID, + ReplyMarkup: getCancelTaskMarkup(task), }) } } +func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup { + return &tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}}, + } +} + func fixTaskFileExt(task *types.Task, localFilePath string) { if path.Ext(task.FileName()) == "" { mimeType, err := mimetype.DetectFile(localFilePath) diff --git a/queue/queue.go b/queue/queue.go index 4ac8733..0434c8c 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -8,30 +8,58 @@ import ( ) type TaskQueue struct { - list *list.List - cond *sync.Cond - mutex *sync.Mutex + list *list.List + cond *sync.Cond + mutex *sync.Mutex + activeMap map[string]*types.Task } -func (q *TaskQueue) AddTask(task types.Task) { +func (q *TaskQueue) AddTask(task *types.Task) { q.mutex.Lock() defer q.mutex.Unlock() - q.list.PushBack(task) - q.cond.Signal() + if task.Status == types.Pending { + q.list.PushBack(task) + q.cond.Signal() + } else { + delete(q.activeMap, task.Key()) + } } -func (q *TaskQueue) GetTask() types.Task { +func (q *TaskQueue) GetTask() *types.Task { q.mutex.Lock() defer q.mutex.Unlock() for q.list.Len() == 0 { q.cond.Wait() } e := q.list.Front() - task := e.Value.(types.Task) + task := e.Value.(*types.Task) q.list.Remove(e) + q.activeMap[task.Key()] = task return task } +func (q *TaskQueue) CancelTask(key string) bool { + q.mutex.Lock() + defer q.mutex.Unlock() + if task, ok := q.activeMap[key]; ok { + if task.Cancel != nil { + task.Cancel() + return true + } + } + for e := q.list.Front(); e != nil; e = e.Next() { + task := e.Value.(*types.Task) + if task.Key() == key { + if task.Cancel != nil { + task.Cancel() + } + q.list.Remove(e) + return true + } + } + return false +} + func (q *TaskQueue) Len() int { q.mutex.Lock() defer q.mutex.Unlock() @@ -47,20 +75,25 @@ func init() { func NewQueue() *TaskQueue { m := &sync.Mutex{} return &TaskQueue{ - list: list.New(), - cond: sync.NewCond(m), - mutex: m, + list: list.New(), + cond: sync.NewCond(m), + mutex: m, + activeMap: make(map[string]*types.Task), } } -func AddTask(task types.Task) { +func AddTask(task *types.Task) { Queue.AddTask(task) } -func GetTask() types.Task { +func GetTask() *types.Task { return Queue.GetTask() } func Len() int { return Queue.Len() } + +func CancelTask(key string) bool { + return Queue.CancelTask(key) +} diff --git a/types/types.go b/types/types.go index 9d448f2..c2e0ca4 100644 --- a/types/types.go +++ b/types/types.go @@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{ type Task struct { Ctx context.Context + Cancel context.CancelFunc Error error Status TaskStatus File *File @@ -52,6 +53,10 @@ type Task struct { UserID int64 } +func (t Task) Key() string { + return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) +} + func (t Task) String() string { return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) }