feat: implement task cancellation feature and update task handling
This commit is contained in:
@@ -163,7 +163,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
task.StoragePath = path.Join(dir.Path, file.FileName)
|
||||
}
|
||||
|
||||
queue.AddTask(task)
|
||||
queue.AddTask(&task)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
|
||||
27
bot/handle_cancel_task.go
Normal file
27
bot/handle_cancel_task.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
)
|
||||
|
||||
func cancelTask(ctx *ext.Context, update *ext.Update) error {
|
||||
key := strings.Split(string(update.CallbackQuery.Data), " ")[1]
|
||||
ok := queue.CancelTask(key)
|
||||
if ok {
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Message: "任务已取消",
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Message: "任务取消失败",
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -22,5 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
|
||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
queue.AddTask(*task)
|
||||
queue.AddTask(task)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
|
||||
ID: task.ReplyMessageID,
|
||||
|
||||
@@ -22,13 +22,13 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
switch task.Status {
|
||||
case types.Pending:
|
||||
logger.L.Infof("Processing task: %s", task.String())
|
||||
if err := processPendingTask(&task); err != nil {
|
||||
logger.L.Errorf("Failed to do task: %s", err)
|
||||
if err := processPendingTask(task); err != nil {
|
||||
task.Error = err
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.L.Debugf("Task canceled: %s", task.String())
|
||||
task.Status = types.Canceled
|
||||
} else {
|
||||
logger.L.Errorf("Failed to do task: %s", err)
|
||||
task.Status = types.Failed
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -48,11 +49,16 @@ func processPendingTask(task *types.Task) error {
|
||||
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
|
||||
}
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
task.Cancel = cancel
|
||||
task.Ctx = cancelCtx
|
||||
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
@@ -63,7 +69,7 @@ func processPendingTask(task *types.Task) error {
|
||||
defer dest.Close()
|
||||
task.StartTime = time.Now()
|
||||
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
_, err = downloadBuider.Parallel(ctx, dest)
|
||||
_, err = downloadBuider.Parallel(cancelCtx, dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -139,13 +139,20 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
|
||||
}
|
||||
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
|
||||
return &tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
|
||||
}
|
||||
}
|
||||
|
||||
func fixTaskFileExt(task *types.Task, localFilePath string) {
|
||||
if path.Ext(task.FileName()) == "" {
|
||||
mimeType, err := mimetype.DetectFile(localFilePath)
|
||||
|
||||
@@ -8,30 +8,58 @@ import (
|
||||
)
|
||||
|
||||
type TaskQueue struct {
|
||||
list *list.List
|
||||
cond *sync.Cond
|
||||
mutex *sync.Mutex
|
||||
list *list.List
|
||||
cond *sync.Cond
|
||||
mutex *sync.Mutex
|
||||
activeMap map[string]*types.Task
|
||||
}
|
||||
|
||||
func (q *TaskQueue) AddTask(task types.Task) {
|
||||
func (q *TaskQueue) AddTask(task *types.Task) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.list.PushBack(task)
|
||||
q.cond.Signal()
|
||||
if task.Status == types.Pending {
|
||||
q.list.PushBack(task)
|
||||
q.cond.Signal()
|
||||
} else {
|
||||
delete(q.activeMap, task.Key())
|
||||
}
|
||||
}
|
||||
|
||||
func (q *TaskQueue) GetTask() types.Task {
|
||||
func (q *TaskQueue) GetTask() *types.Task {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
for q.list.Len() == 0 {
|
||||
q.cond.Wait()
|
||||
}
|
||||
e := q.list.Front()
|
||||
task := e.Value.(types.Task)
|
||||
task := e.Value.(*types.Task)
|
||||
q.list.Remove(e)
|
||||
q.activeMap[task.Key()] = task
|
||||
return task
|
||||
}
|
||||
|
||||
func (q *TaskQueue) CancelTask(key string) bool {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
if task, ok := q.activeMap[key]; ok {
|
||||
if task.Cancel != nil {
|
||||
task.Cancel()
|
||||
return true
|
||||
}
|
||||
}
|
||||
for e := q.list.Front(); e != nil; e = e.Next() {
|
||||
task := e.Value.(*types.Task)
|
||||
if task.Key() == key {
|
||||
if task.Cancel != nil {
|
||||
task.Cancel()
|
||||
}
|
||||
q.list.Remove(e)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (q *TaskQueue) Len() int {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
@@ -47,20 +75,25 @@ func init() {
|
||||
func NewQueue() *TaskQueue {
|
||||
m := &sync.Mutex{}
|
||||
return &TaskQueue{
|
||||
list: list.New(),
|
||||
cond: sync.NewCond(m),
|
||||
mutex: m,
|
||||
list: list.New(),
|
||||
cond: sync.NewCond(m),
|
||||
mutex: m,
|
||||
activeMap: make(map[string]*types.Task),
|
||||
}
|
||||
}
|
||||
|
||||
func AddTask(task types.Task) {
|
||||
func AddTask(task *types.Task) {
|
||||
Queue.AddTask(task)
|
||||
}
|
||||
|
||||
func GetTask() types.Task {
|
||||
func GetTask() *types.Task {
|
||||
return Queue.GetTask()
|
||||
}
|
||||
|
||||
func Len() int {
|
||||
return Queue.Len()
|
||||
}
|
||||
|
||||
func CancelTask(key string) bool {
|
||||
return Queue.CancelTask(key)
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{
|
||||
|
||||
type Task struct {
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
Error error
|
||||
Status TaskStatus
|
||||
File *File
|
||||
@@ -52,6 +53,10 @@ type Task struct {
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) Key() string {
|
||||
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user