refactor: refactor task logic for better scalability (#76)

* refactor: a big refactor. wip

* refactor: port handle file

* refactor: place all handlers

* fix: task info nil pointer

* feat: enhance task progress tracking and context management

* feat: cancel task

* feat: stream mode

* feat: silent mode

* feat: dir cmd

* refactor: remove unused old file

* feat: rule cmd

* feat: handle silent mode

* feat: batch task

* fix: batch task progress and temp file cleanup

* refactor: update file creation and cleanup methods for better resource management

* feat: add save command with silent mode handling

* feat: message link

* feat: update message prompts to include file count in storage selection

* feat: slient save links

* refactor: reduce dup code

* feat: rule type

* feat: chose dir

* feat: refactor file handling and storage rules, improve error handling and logging

* feat: rule mode

* feat: telegraph pics

* fix: tphpics nil pointer and inaccurate dirpath

* feat: silent save telegraph

* feat: add suffix to avoid file overwrite

* feat: new storage telegram

* chore: tidy go mod
This commit is contained in:
Krau
2025-06-15 23:57:49 +08:00
committed by GitHub
parent 280745cae3
commit 900823cdb9
150 changed files with 5730 additions and 3923 deletions

82
core/tftask/execute.go Normal file
View File

@@ -0,0 +1,82 @@
package tftask
import (
"context"
"fmt"
"os"
"path"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/key"
)
func (t *TGFileTask) Execute(ctx context.Context) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
t.Progress.OnStart(ctx, t)
if t.stream {
return executeStream(ctx, t)
}
logger.Info("Starting file download")
localFile, err := fsutil.CreateFile(t.localPath)
if err != nil {
return fmt.Errorf("failed to create local file: %w", err)
}
defer func() {
if err := localFile.CloseAndRemove(); err != nil {
logger.Errorf("Failed to close local file: %v", err)
}
}()
wrAt := newWriterAt(ctx, localFile, t.Progress, t)
defer func() {
t.Progress.OnDone(ctx, t, err)
}()
_, err = tdler.NewDownloader(t.client, t.File).Parallel(ctx, wrAt)
if err != nil {
return fmt.Errorf("failed to download file: %w", err)
}
logger.Infof("File downloaded successfully")
if path.Ext(t.File.Name()) == "" {
ext := fsutil.DetectFileExt(t.localPath)
if ext != "" {
t.Path = t.Path + ext
}
}
var fileStat os.FileInfo
fileStat, err = os.Stat(t.localPath)
if err != nil {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size())
for i := range config.Cfg.Retry + 1 {
if err = vctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
}
var file *os.File
file, err = os.Open(t.localPath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
if err = t.Storage.Save(vctx, file, t.Path); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
logger.Errorf("Failed to save file: %s, retrying...", err)
select {
case <-vctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond):
}
continue
}
return nil
}
return fmt.Errorf("failed to save file after retries")
}

186
core/tftask/progress.go Normal file
View File

@@ -0,0 +1,186 @@
package tftask
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
type ProgressTracker interface {
OnStart(ctx context.Context, info TaskInfo)
OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64)
OnDone(ctx context.Context, info TaskInfo, err error)
}
type Progress struct {
MessageID int
ChatID int64
start time.Time
lastUpdatePercent atomic.Int32
}
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
p.start = time.Now()
p.lastUpdatePercent.Store(0)
log.FromContext(ctx).Debugf("Progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("开始下载\n文件名: "),
styling.Code(info.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
styling.Plain("\n文件大小: "),
styling.Code(fmt.Sprintf("%.2f MB", float64(info.FileSize())/(1024*1024))),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(info.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
return
}
}
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) {
if !shouldUpdateProgress(total, downloaded, int(p.lastUpdatePercent.Load())) {
return
}
percent := int32((downloaded * 100) / total)
if p.lastUpdatePercent.Load() == percent {
return
}
p.lastUpdatePercent.Store(percent)
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.FileName(), downloaded, total)
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(info.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
styling.Plain("\n文件大小: "),
styling.Code(fmt.Sprintf("%.2f MB", float64(total)/(1024*1024))),
styling.Plain("\n平均速度: "),
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))),
styling.Plain("\n当前进度: "),
styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(info.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
return
}
}
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
if err != nil {
log.FromContext(ctx).Errorf("Progress error for file [%s]: %v", info.FileName(), err)
} else {
log.FromContext(ctx).Debugf("Progress done for file [%s]", info.FileName())
}
entityBuilder := entity.Builder{}
var stylingErr error
if err != nil {
if errors.Is(err, context.Canceled) {
stylingErr = styling.Perform(&entityBuilder,
styling.Plain("任务已取消\n文件名: "),
styling.Code(info.FileName()),
)
} else {
stylingErr = styling.Perform(&entityBuilder,
styling.Plain("下载失败\n文件名: "),
styling.Code(info.FileName()),
styling.Plain("\n错误: "),
styling.Bold(err.Error()),
)
}
} else {
stylingErr = styling.Perform(&entityBuilder,
styling.Plain("下载完成\n文件名: "),
styling.Code(info.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
)
}
if stylingErr != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
}
}
type ProgressOption func(*Progress)
func NewProgressTrack(
messageID int,
chatID int64,
opts ...ProgressOption,
) ProgressTracker {
p := &Progress{
MessageID: messageID,
ChatID: chatID,
}
for _, opt := range opts {
opt(p)
}
return p
}

40
core/tftask/stream.go Normal file
View File

@@ -0,0 +1,40 @@
package tftask
import (
"context"
"fmt"
"io"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/common/tdler"
"golang.org/x/sync/errgroup"
)
func executeStream(ctx context.Context, task *TGFileTask) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
pr, pw := io.Pipe()
defer pr.Close()
errg, uploadCtx := errgroup.WithContext(ctx)
errg.Go(func() error {
return task.Storage.Save(uploadCtx, pr, task.Path)
})
wr := newWriter(ctx, pw, task.Progress, task)
errg.Go(func() error {
logger.Info("Starting file download in stream mode")
_, err := tdler.NewDownloader(task.client, task.File).Stream(uploadCtx, wr)
if closeErr := pw.CloseWithError(err); closeErr != nil {
logger.Errorf("Failed to close pipe writer: %v", closeErr)
}
return err
})
var err error
defer func() {
task.Progress.OnDone(ctx, task, err)
}()
if err = errg.Wait(); err != nil {
return err
}
logger.Info("File downloaded successfully in stream mode")
return nil
}

29
core/tftask/taskinfo.go Normal file
View File

@@ -0,0 +1,29 @@
package tftask
type TaskInfo interface {
TaskID() string
FileName() string
FileSize() int64
StoragePath() string
StorageName() string
}
func (t *TGFileTask) TaskID() string {
return t.ID
}
func (t *TGFileTask) FileName() string {
return t.File.Name()
}
func (t *TGFileTask) FileSize() int64 {
return t.File.Size()
}
func (t *TGFileTask) StoragePath() string {
return t.Path
}
func (t *TGFileTask) StorageName() string {
return t.Storage.Name()
}

64
core/tftask/tftask.go Normal file
View File

@@ -0,0 +1,64 @@
package tftask
import (
"context"
"fmt"
"path/filepath"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
)
type TGFileTask struct {
ID string
Ctx context.Context
File tfile.TGFile
Storage storage.Storage
Path string
Progress ProgressTracker
client tdler.Client
stream bool // true if the file should be downloaded in stream mode
localPath string
}
func NewTGFileTask(
id string,
ctx context.Context,
file tfile.TGFile,
client tdler.Client,
stor storage.Storage,
path string,
progress ProgressTracker,
) (*TGFileTask, error) {
_, ok := stor.(storage.StorageCannotStream)
if !config.Cfg.Stream || ok {
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
if err != nil {
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
}
tftask := &TGFileTask{
ID: id,
Ctx: ctx,
client: client,
File: file,
Storage: stor,
Path: path,
Progress: progress,
localPath: cachePath,
}
return tftask, nil
}
tfileTask := &TGFileTask{
ID: id,
Ctx: ctx,
client: client,
File: file,
Storage: stor,
Path: path,
Progress: progress,
stream: true,
}
return tfileTask, nil
}

32
core/tftask/util.go Normal file
View File

@@ -0,0 +1,32 @@
package tftask
var progressUpdatesLevels = []struct {
size int64 // 文件大小阈值
stepPercent int // 每多少 % 更新一次
}{
{10 << 20, 100},
{50 << 20, 20},
{200 << 20, 10},
{500 << 20, 5},
}
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
if total <= 0 || downloaded <= 0 {
return false
}
percent := int((downloaded * 100) / total)
if percent <= lastUpdatePercent {
return false
}
step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent
for _, lvl := range progressUpdatesLevels {
if total < lvl.size {
step = lvl.stepPercent
break
}
}
return percent >= lastUpdatePercent+step
}

75
core/tftask/writer.go Normal file
View File

@@ -0,0 +1,75 @@
package tftask
import (
"context"
"io"
"sync/atomic"
)
type ProgressWriterAt struct {
ctx context.Context
wrAt io.WriterAt
progress ProgressTracker
downloaded *atomic.Int64
total int64
info TaskInfo
}
func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
at, err := w.wrAt.WriteAt(p, off)
if err != nil {
return 0, err
}
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
return at, nil
}
func newWriterAt(
ctx context.Context,
wrAt io.WriterAt,
progress ProgressTracker,
taskInfo TaskInfo,
) *ProgressWriterAt {
return &ProgressWriterAt{
ctx: ctx,
progress: progress,
downloaded: &atomic.Int64{},
total: taskInfo.FileSize(),
wrAt: wrAt,
info: taskInfo,
}
}
type ProgressWriter struct {
ctx context.Context
wrAt io.Writer
progress ProgressTracker
downloaded *atomic.Int64
total int64
info TaskInfo
}
func (w *ProgressWriter) Write(p []byte) (int, error) {
at, err := w.wrAt.Write(p)
if err != nil {
return 0, err
}
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
return at, nil
}
func newWriter(
ctx context.Context,
wr io.Writer,
progress ProgressTracker,
taskInfo TaskInfo,
) *ProgressWriter {
return &ProgressWriter{
ctx: ctx,
progress: progress,
downloaded: &atomic.Int64{},
total: taskInfo.FileSize(),
wrAt: wr,
info: taskInfo,
}
}