mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-05-12 01:39:42 +08:00
refactor: refactor task logic for better scalability (#76)
* refactor: a big refactor. wip * refactor: port handle file * refactor: place all handlers * fix: task info nil pointer * feat: enhance task progress tracking and context management * feat: cancel task * feat: stream mode * feat: silent mode * feat: dir cmd * refactor: remove unused old file * feat: rule cmd * feat: handle silent mode * feat: batch task * fix: batch task progress and temp file cleanup * refactor: update file creation and cleanup methods for better resource management * feat: add save command with silent mode handling * feat: message link * feat: update message prompts to include file count in storage selection * feat: slient save links * refactor: reduce dup code * feat: rule type * feat: chose dir * feat: refactor file handling and storage rules, improve error handling and logging * feat: rule mode * feat: telegraph pics * fix: tphpics nil pointer and inaccurate dirpath * feat: silent save telegraph * feat: add suffix to avoid file overwrite * feat: new storage telegram * chore: tidy go mod
This commit is contained in:
82
core/tftask/execute.go
Normal file
82
core/tftask/execute.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/key"
|
||||
)
|
||||
|
||||
func (t *TGFileTask) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
|
||||
t.Progress.OnStart(ctx, t)
|
||||
if t.stream {
|
||||
return executeStream(ctx, t)
|
||||
}
|
||||
|
||||
logger.Info("Starting file download")
|
||||
localFile, err := fsutil.CreateFile(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := localFile.CloseAndRemove(); err != nil {
|
||||
logger.Errorf("Failed to close local file: %v", err)
|
||||
}
|
||||
}()
|
||||
wrAt := newWriterAt(ctx, localFile, t.Progress, t)
|
||||
|
||||
defer func() {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}()
|
||||
_, err = tdler.NewDownloader(t.client, t.File).Parallel(ctx, wrAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
logger.Infof("File downloaded successfully")
|
||||
if path.Ext(t.File.Name()) == "" {
|
||||
ext := fsutil.DetectFileExt(t.localPath)
|
||||
if ext != "" {
|
||||
t.Path = t.Path + ext
|
||||
}
|
||||
}
|
||||
var fileStat os.FileInfo
|
||||
fileStat, err = os.Stat(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size())
|
||||
for i := range config.Cfg.Retry + 1 {
|
||||
if err = vctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||
}
|
||||
var file *os.File
|
||||
file, err = os.Open(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if err = t.Storage.Save(vctx, file, t.Path); err != nil {
|
||||
if i == config.Cfg.Retry {
|
||||
return fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
logger.Errorf("Failed to save file: %s, retrying...", err)
|
||||
select {
|
||||
case <-vctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
|
||||
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to save file after retries")
|
||||
|
||||
}
|
||||
186
core/tftask/progress.go
Normal file
186
core/tftask/progress.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, info TaskInfo)
|
||||
OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64)
|
||||
OnDone(ctx context.Context, info TaskInfo, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
MessageID int
|
||||
ChatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
log.FromContext(ctx).Debugf("Progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("开始下载\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
styling.Plain("\n文件大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB", float64(info.FileSize())/(1024*1024))),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) {
|
||||
if !shouldUpdateProgress(total, downloaded, int(p.lastUpdatePercent.Load())) {
|
||||
return
|
||||
}
|
||||
percent := int32((downloaded * 100) / total)
|
||||
if p.lastUpdatePercent.Load() == percent {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(percent)
|
||||
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.FileName(), downloaded, total)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在处理下载任务\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
styling.Plain("\n文件大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB", float64(total)/(1024*1024))),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))),
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Errorf("Progress error for file [%s]: %v", info.FileName(), err)
|
||||
} else {
|
||||
log.FromContext(ctx).Debugf("Progress done for file [%s]", info.FileName())
|
||||
}
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
var stylingErr error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("任务已取消\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
)
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("下载失败\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n错误: "),
|
||||
styling.Bold(err.Error()),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("下载完成\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
)
|
||||
}
|
||||
|
||||
if stylingErr != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressOption func(*Progress)
|
||||
|
||||
func NewProgressTrack(
|
||||
messageID int,
|
||||
chatID int64,
|
||||
opts ...ProgressOption,
|
||||
) ProgressTracker {
|
||||
p := &Progress{
|
||||
MessageID: messageID,
|
||||
ChatID: chatID,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
40
core/tftask/stream.go
Normal file
40
core/tftask/stream.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func executeStream(ctx context.Context, task *TGFileTask) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
errg, uploadCtx := errgroup.WithContext(ctx)
|
||||
errg.Go(func() error {
|
||||
return task.Storage.Save(uploadCtx, pr, task.Path)
|
||||
})
|
||||
wr := newWriter(ctx, pw, task.Progress, task)
|
||||
errg.Go(func() error {
|
||||
logger.Info("Starting file download in stream mode")
|
||||
_, err := tdler.NewDownloader(task.client, task.File).Stream(uploadCtx, wr)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
logger.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
var err error
|
||||
defer func() {
|
||||
task.Progress.OnDone(ctx, task, err)
|
||||
}()
|
||||
if err = errg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("File downloaded successfully in stream mode")
|
||||
return nil
|
||||
}
|
||||
29
core/tftask/taskinfo.go
Normal file
29
core/tftask/taskinfo.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tftask
|
||||
|
||||
type TaskInfo interface {
|
||||
TaskID() string
|
||||
FileName() string
|
||||
FileSize() int64
|
||||
StoragePath() string
|
||||
StorageName() string
|
||||
}
|
||||
|
||||
func (t *TGFileTask) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileName() string {
|
||||
return t.File.Name()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileSize() int64 {
|
||||
return t.File.Size()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StoragePath() string {
|
||||
return t.Path
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StorageName() string {
|
||||
return t.Storage.Name()
|
||||
}
|
||||
64
core/tftask/tftask.go
Normal file
64
core/tftask/tftask.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
type TGFileTask struct {
|
||||
ID string
|
||||
Ctx context.Context
|
||||
File tfile.TGFile
|
||||
Storage storage.Storage
|
||||
Path string
|
||||
Progress ProgressTracker
|
||||
client tdler.Client
|
||||
stream bool // true if the file should be downloaded in stream mode
|
||||
localPath string
|
||||
}
|
||||
|
||||
func NewTGFileTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
file tfile.TGFile,
|
||||
client tdler.Client,
|
||||
stor storage.Storage,
|
||||
path string,
|
||||
progress ProgressTracker,
|
||||
) (*TGFileTask, error) {
|
||||
_, ok := stor.(storage.StorageCannotStream)
|
||||
if !config.Cfg.Stream || ok {
|
||||
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
|
||||
}
|
||||
tftask := &TGFileTask{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
File: file,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
Progress: progress,
|
||||
localPath: cachePath,
|
||||
}
|
||||
return tftask, nil
|
||||
}
|
||||
tfileTask := &TGFileTask{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
File: file,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
Progress: progress,
|
||||
stream: true,
|
||||
}
|
||||
return tfileTask, nil
|
||||
}
|
||||
32
core/tftask/util.go
Normal file
32
core/tftask/util.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package tftask
|
||||
|
||||
var progressUpdatesLevels = []struct {
|
||||
size int64 // 文件大小阈值
|
||||
stepPercent int // 每多少 % 更新一次
|
||||
}{
|
||||
{10 << 20, 100},
|
||||
{50 << 20, 20},
|
||||
{200 << 20, 10},
|
||||
{500 << 20, 5},
|
||||
}
|
||||
|
||||
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
|
||||
if total <= 0 || downloaded <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
percent := int((downloaded * 100) / total)
|
||||
if percent <= lastUpdatePercent {
|
||||
return false
|
||||
}
|
||||
|
||||
step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent
|
||||
for _, lvl := range progressUpdatesLevels {
|
||||
if total < lvl.size {
|
||||
step = lvl.stepPercent
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return percent >= lastUpdatePercent+step
|
||||
}
|
||||
75
core/tftask/writer.go
Normal file
75
core/tftask/writer.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type ProgressWriterAt struct {
|
||||
ctx context.Context
|
||||
wrAt io.WriterAt
|
||||
progress ProgressTracker
|
||||
downloaded *atomic.Int64
|
||||
total int64
|
||||
info TaskInfo
|
||||
}
|
||||
|
||||
func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
at, err := w.wrAt.WriteAt(p, off)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func newWriterAt(
|
||||
ctx context.Context,
|
||||
wrAt io.WriterAt,
|
||||
progress ProgressTracker,
|
||||
taskInfo TaskInfo,
|
||||
) *ProgressWriterAt {
|
||||
return &ProgressWriterAt{
|
||||
ctx: ctx,
|
||||
progress: progress,
|
||||
downloaded: &atomic.Int64{},
|
||||
total: taskInfo.FileSize(),
|
||||
wrAt: wrAt,
|
||||
info: taskInfo,
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
ctx context.Context
|
||||
wrAt io.Writer
|
||||
progress ProgressTracker
|
||||
downloaded *atomic.Int64
|
||||
total int64
|
||||
info TaskInfo
|
||||
}
|
||||
|
||||
func (w *ProgressWriter) Write(p []byte) (int, error) {
|
||||
at, err := w.wrAt.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func newWriter(
|
||||
ctx context.Context,
|
||||
wr io.Writer,
|
||||
progress ProgressTracker,
|
||||
taskInfo TaskInfo,
|
||||
) *ProgressWriter {
|
||||
return &ProgressWriter{
|
||||
ctx: ctx,
|
||||
progress: progress,
|
||||
downloaded: &atomic.Int64{},
|
||||
total: taskInfo.FileSize(),
|
||||
wrAt: wr,
|
||||
info: taskInfo,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user