diff --git a/core/core.go b/core/core.go index b2dcc62..4852d37 100644 --- a/core/core.go +++ b/core/core.go @@ -25,7 +25,6 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { if err := processPendingTask(task); err != nil { task.Error = err if errors.Is(err, context.Canceled) { - logger.L.Debugf("Task canceled: %s", task.String()) task.Status = types.Canceled } else { logger.L.Errorf("Failed to do task: %s", err) @@ -37,23 +36,43 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { queue.AddTask(task) case types.Succeeded: logger.L.Infof("Task succeeded: %s", task.String()) - task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), - ID: task.ReplyMessageID, - }) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath), + ID: task.ReplyMessageID, + }) + } case types.Failed: logger.L.Errorf("Task failed: %s", task.String()) - task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: "文件保存失败\n" + task.Error.Error(), - ID: task.ReplyMessageID, - }) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: "文件保存失败\n" + task.Error.Error(), + ID: task.ReplyMessageID, + }) + } case types.Canceled: logger.L.Infof("Task canceled: %s", task.String()) + extCtx, ok := task.Ctx.(*ext.Context) + if !ok { + logger.L.Errorf("Context is not *ext.Context: %T", task.Ctx) + } else { + extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ + Message: "任务已取消", + ID: task.ReplyMessageID, + }) + } default: logger.L.Errorf("Unknown task status: %s", task.Status) } <-semaphore - logger.L.Debugf("Task done: %s", task.String()) + logger.L.Debugf("Task done: %s; status: %s", task.String(), task.Status) + queue.DoneTask(task) } } diff --git a/core/download.go b/core/download.go index e2db1bb..cce4308 100644 --- a/core/download.go +++ b/core/download.go @@ -51,7 +51,6 @@ func processPendingTask(task *types.Task) error { cancelCtx, cancel := context.WithCancel(ctx) task.Cancel = cancel - task.Ctx = cancelCtx text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ @@ -84,5 +83,5 @@ func processPendingTask(task *types.Task) error { ID: task.ReplyMessageID, }) - return saveFileWithRetry(task, taskStorage, cacheDestPath) + return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath) } diff --git a/core/utils.go b/core/utils.go index ced499e..fed31cf 100644 --- a/core/utils.go +++ b/core/utils.go @@ -1,6 +1,7 @@ package core import ( + "context" "fmt" "os" "path" @@ -19,19 +20,19 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error { +func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, localFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { - if err := task.Ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) } - if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil { + if err := taskStorage.Save(ctx, localFilePath, task.StoragePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } logger.L.Errorf("Failed to save file: %s, retrying...", err) select { - case <-task.Ctx.Done(): - return fmt.Errorf("context canceled during retry delay: %w", task.Ctx.Err()) + case <-ctx.Done(): + return fmt.Errorf("context canceled during retry delay: %w", ctx.Err()) case <-time.After(time.Duration(i*500) * time.Millisecond): } continue @@ -64,7 +65,7 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin logger.L.Infof("Downloaded file: %s", cachePath) - return saveFileWithRetry(task, taskStorage, cachePath) + return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath) } // func getProgressBar(progress float64, updateCount int) string { diff --git a/queue/queue.go b/queue/queue.go index 0434c8c..78401fc 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -17,10 +17,9 @@ type TaskQueue struct { func (q *TaskQueue) AddTask(task *types.Task) { q.mutex.Lock() defer q.mutex.Unlock() - if task.Status == types.Pending { - q.list.PushBack(task) - q.cond.Signal() - } else { + q.list.PushBack(task) + q.cond.Signal() + if task.Status != types.Pending { delete(q.activeMap, task.Key()) } } @@ -34,10 +33,18 @@ func (q *TaskQueue) GetTask() *types.Task { e := q.list.Front() task := e.Value.(*types.Task) q.list.Remove(e) - q.activeMap[task.Key()] = task + if task.Status == types.Pending { + q.activeMap[task.Key()] = task + } return task } +func (q *TaskQueue) DoneTask(task *types.Task) { + q.mutex.Lock() + defer q.mutex.Unlock() + delete(q.activeMap, task.Key()) +} + func (q *TaskQueue) CancelTask(key string) bool { q.mutex.Lock() defer q.mutex.Unlock() @@ -97,3 +104,7 @@ func Len() int { func CancelTask(key string) bool { return Queue.CancelTask(key) } + +func DoneTask(task *types.Task) { + Queue.DoneTask(task) +}