mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-28 02:31:34 +08:00
refactor: refactor task logic for better scalability (#76)
* refactor: a big refactor. wip * refactor: port handle file * refactor: place all handlers * fix: task info nil pointer * feat: enhance task progress tracking and context management * feat: cancel task * feat: stream mode * feat: silent mode * feat: dir cmd * refactor: remove unused old file * feat: rule cmd * feat: handle silent mode * feat: batch task * fix: batch task progress and temp file cleanup * refactor: update file creation and cleanup methods for better resource management * feat: add save command with silent mode handling * feat: message link * feat: update message prompts to include file count in storage selection * feat: slient save links * refactor: reduce dup code * feat: rule type * feat: chose dir * feat: refactor file handling and storage rules, improve error handling and logging * feat: rule mode * feat: telegraph pics * fix: tphpics nil pointer and inaccurate dirpath * feat: silent save telegraph * feat: add suffix to avoid file overwrite * feat: new storage telegram * chore: tidy go mod
This commit is contained in:
122
core/batchtftask/execute.go
Normal file
122
core/batchtftask/execute.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package batchtftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/duke-git/lancet/v2/retry"
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/key"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_file[%s]", t.ID))
|
||||
logger.Info("Starting batch file task")
|
||||
t.Progress.OnStart(ctx, t)
|
||||
workers := config.Cfg.Workers
|
||||
eg, gctx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(workers)
|
||||
for _, elem := range t.Elems {
|
||||
elem := elem
|
||||
eg.Go(func() error {
|
||||
if t.processing[elem.ID] != nil {
|
||||
return fmt.Errorf("element with ID %s is already being processed", elem.ID)
|
||||
}
|
||||
t.processing[elem.ID] = &elem
|
||||
defer func() {
|
||||
delete(t.processing, elem.ID)
|
||||
}()
|
||||
return t.processElement(gctx, elem)
|
||||
})
|
||||
}
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
logger.Errorf("Error during batch file processing: %v", err)
|
||||
} else {
|
||||
logger.Info("Batch file task completed successfully")
|
||||
}
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.File.Name()))
|
||||
if elem.stream {
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
errg, uploadCtx := errgroup.WithContext(ctx)
|
||||
errg.Go(func() error {
|
||||
return elem.Storage.Save(uploadCtx, pr, elem.Path)
|
||||
})
|
||||
wr := ioutil.NewProgressWriter(pw, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
})
|
||||
errg.Go(func() error {
|
||||
logger.Info("Starting file download in stream mode")
|
||||
_, err := tdler.NewDownloader(t.client, elem.File).Stream(uploadCtx, wr)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
logger.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err := errg.Wait(); err != nil {
|
||||
return fmt.Errorf("failed to download file in stream mode: %w", err)
|
||||
}
|
||||
logger.Info("File downloaded successfully in stream mode")
|
||||
return nil
|
||||
}
|
||||
logger.Info("Starting file download")
|
||||
localFile, err := fsutil.CreateFile(elem.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := localFile.CloseAndRemove(); err != nil {
|
||||
logger.Errorf("Failed to close local file: %v", err)
|
||||
}
|
||||
}()
|
||||
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
})
|
||||
_, err = tdler.NewDownloader(t.client, elem.File).Parallel(ctx, wrAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
logger.Info("File downloaded successfully")
|
||||
if path.Ext(elem.FileName()) == "" {
|
||||
ext := fsutil.DetectFileExt(elem.localPath)
|
||||
if ext != "" {
|
||||
elem.Path = elem.Path + ext
|
||||
}
|
||||
}
|
||||
var fileStat os.FileInfo
|
||||
fileStat, err = os.Stat(elem.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size())
|
||||
err = retry.Retry(func() error {
|
||||
var file *os.File
|
||||
file, err = os.Open(elem.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if err = elem.Storage.Save(vctx, file, elem.Path); err != nil {
|
||||
logger.Errorf("Failed to save file: %s, retrying...", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, retry.Context(vctx), retry.RetryTimes(uint(config.Cfg.Retry)))
|
||||
return err
|
||||
}
|
||||
176
core/batchtftask/progress.go
Normal file
176
core/batchtftask/progress.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package batchtftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, info TaskInfo)
|
||||
OnProgress(ctx context.Context, info TaskInfo)
|
||||
OnDone(ctx context.Context, info TaskInfo, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
MessageID int
|
||||
ChatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
log.FromContext(ctx).Debugf("Batch task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("开始执行批量下载任务\n总大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
|
||||
if !shouldUpdateProgress(info.TotalSize(), info.Downloaded(), int(p.lastUpdatePercent.Load())) {
|
||||
return
|
||||
}
|
||||
percent := int((info.Downloaded() * 100) / info.TotalSize())
|
||||
if p.lastUpdatePercent.Load() == int32(percent) {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(int32(percent))
|
||||
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Downloaded(), info.TotalSize())
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在处理批量下载任务\n总大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())),
|
||||
styling.Plain("\n正在处理:\n"),
|
||||
func() styling.StyledTextOption {
|
||||
var lines []string
|
||||
for _, elem := range info.Processing() {
|
||||
lines = append(lines, fmt.Sprintf(" - %s (%.2f MB)", elem.FileName(), float64(elem.FileSize())/(1024*1024)))
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
lines = append(lines, " - 无")
|
||||
}
|
||||
return styling.Plain(slice.Join(lines, "\n"))
|
||||
}(),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(info.Downloaded(), p.start)/(1024*1024))),
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", float64(info.Downloaded())/float64(info.TotalSize())*100)),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Errorf("Batch task %s failed: %s", info.TaskID(), err)
|
||||
} else {
|
||||
log.FromContext(ctx).Debugf("Batch task %s completed successfully", info.TaskID())
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
var stylingErr error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("任务已取消"),
|
||||
)
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("处理失败, 错误:\n "),
|
||||
styling.Code(err.Error()),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("处理完成\n文件数: "),
|
||||
styling.Code(strconv.Itoa(info.Count())),
|
||||
styling.Plain("\n总大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB", float64(info.TotalSize())/(1024*1024))),
|
||||
)
|
||||
}
|
||||
|
||||
if stylingErr != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
func NewProgressTracker(messageID int, chatID int64) ProgressTracker {
|
||||
return &Progress{
|
||||
MessageID: messageID,
|
||||
ChatID: chatID,
|
||||
}
|
||||
}
|
||||
94
core/batchtftask/task.go
Normal file
94
core/batchtftask/task.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package batchtftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
type TaskElement struct {
|
||||
ID string
|
||||
Storage storage.Storage
|
||||
Path string
|
||||
File tfile.TGFile
|
||||
localPath string
|
||||
stream bool
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
Ctx context.Context
|
||||
Elems []TaskElement
|
||||
Progress ProgressTracker
|
||||
IgnoreErrors bool // if true, errors during processing will be ignored
|
||||
downloaded atomic.Int64
|
||||
client tdler.Client
|
||||
totalSize int64
|
||||
processing map[string]TaskElementInfo
|
||||
failed map[string]error // errors for each element
|
||||
}
|
||||
|
||||
func NewTaskElement(
|
||||
stor storage.Storage,
|
||||
path string,
|
||||
file tfile.TGFile,
|
||||
) (*TaskElement, error) {
|
||||
id := xid.New().String()
|
||||
_, ok := stor.(storage.StorageCannotStream)
|
||||
if !config.Cfg.Stream || ok {
|
||||
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
|
||||
}
|
||||
return &TaskElement{
|
||||
ID: id,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
File: file,
|
||||
localPath: cachePath,
|
||||
}, nil
|
||||
}
|
||||
return &TaskElement{
|
||||
ID: id,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
File: file,
|
||||
stream: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewBatchTGFileTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
files []TaskElement,
|
||||
client tdler.Client,
|
||||
progress ProgressTracker,
|
||||
ignoreErrors bool,
|
||||
) *Task {
|
||||
task := &Task{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
Elems: files,
|
||||
Progress: progress,
|
||||
downloaded: atomic.Int64{},
|
||||
totalSize: func() int64 {
|
||||
var total int64
|
||||
for _, elem := range files {
|
||||
total += elem.File.Size()
|
||||
}
|
||||
return total
|
||||
}(),
|
||||
processing: make(map[string]TaskElementInfo),
|
||||
IgnoreErrors: ignoreErrors,
|
||||
failed: make(map[string]error),
|
||||
}
|
||||
return task
|
||||
}
|
||||
56
core/batchtftask/taskinfo.go
Normal file
56
core/batchtftask/taskinfo.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package batchtftask
|
||||
|
||||
type TaskElementInfo interface {
|
||||
FileName() string
|
||||
FileSize() int64
|
||||
StoragePath() string
|
||||
StorageName() string
|
||||
}
|
||||
|
||||
func (e *TaskElement) FileName() string {
|
||||
return e.File.Name()
|
||||
}
|
||||
|
||||
func (e *TaskElement) FileSize() int64 {
|
||||
return e.File.Size()
|
||||
}
|
||||
|
||||
func (e *TaskElement) StoragePath() string {
|
||||
return e.Path
|
||||
}
|
||||
|
||||
func (e *TaskElement) StorageName() string {
|
||||
return e.Storage.Name()
|
||||
}
|
||||
|
||||
type TaskInfo interface {
|
||||
TaskID() string
|
||||
TotalSize() int64
|
||||
Downloaded() int64
|
||||
Count() int
|
||||
Processing() []TaskElementInfo
|
||||
}
|
||||
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *Task) TotalSize() int64 {
|
||||
return t.totalSize
|
||||
}
|
||||
|
||||
func (t *Task) Downloaded() int64 {
|
||||
return t.downloaded.Load()
|
||||
}
|
||||
|
||||
func (t *Task) Count() int {
|
||||
return len(t.Elems)
|
||||
}
|
||||
|
||||
func (t *Task) Processing() []TaskElementInfo {
|
||||
processing := make([]TaskElementInfo, 0, len(t.Elems))
|
||||
for _, elem := range t.processing {
|
||||
processing = append(processing, elem)
|
||||
}
|
||||
return processing
|
||||
}
|
||||
32
core/batchtftask/utils.go
Normal file
32
core/batchtftask/utils.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package batchtftask
|
||||
|
||||
var progressUpdatesLevels = []struct {
|
||||
size int64 // 文件大小阈值
|
||||
stepPercent int // 每多少 % 更新一次
|
||||
}{
|
||||
{10 << 20, 100},
|
||||
{50 << 20, 20},
|
||||
{200 << 20, 10},
|
||||
{500 << 20, 5},
|
||||
}
|
||||
|
||||
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
|
||||
if total <= 0 || downloaded <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
percent := int((downloaded * 100) / total)
|
||||
if percent <= lastUpdatePercent {
|
||||
return false
|
||||
}
|
||||
|
||||
step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent
|
||||
for _, lvl := range progressUpdatesLevels {
|
||||
if total < lvl.size {
|
||||
step = lvl.stepPercent
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return percent >= lastUpdatePercent+step
|
||||
}
|
||||
112
core/core.go
112
core/core.go
@@ -2,92 +2,62 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/telegram/downloader"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
)
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
var queueInstance *queue.TaskQueue[Exectable]
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
type Exectable interface {
|
||||
TaskID() string
|
||||
Execute(ctx context.Context) error
|
||||
}
|
||||
|
||||
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) {
|
||||
for {
|
||||
semaphore <- struct{}{}
|
||||
task := queue.GetTask()
|
||||
common.Log.Debugf("Got task: %s", task.String())
|
||||
|
||||
switch task.Status {
|
||||
case types.Pending:
|
||||
common.Log.Infof("Processing task: %s", task.String())
|
||||
if err := processPendingTask(task); err != nil {
|
||||
task.Error = err
|
||||
if errors.Is(err, context.Canceled) {
|
||||
task.Status = types.Canceled
|
||||
} else {
|
||||
common.Log.Errorf("Failed to do task: %s", err)
|
||||
task.Status = types.Failed
|
||||
}
|
||||
} else {
|
||||
task.Status = types.Succeeded
|
||||
}
|
||||
queue.AddTask(task)
|
||||
case types.Succeeded:
|
||||
common.Log.Infof("Task succeeded: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
}
|
||||
case types.Failed:
|
||||
common.Log.Errorf("Task failed: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "文件保存失败\n" + task.Error.Error(),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
}
|
||||
case types.Canceled:
|
||||
common.Log.Infof("Task canceled: %s", task.String())
|
||||
extCtx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
|
||||
} else if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: "任务已取消",
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
}
|
||||
default:
|
||||
common.Log.Errorf("Unknown task status: %s", task.Status)
|
||||
qtask, err := qe.Get()
|
||||
if err != nil {
|
||||
break // queue closed and empty
|
||||
}
|
||||
log.FromContext(ctx).Infof("Processing task: %s", qtask.ID)
|
||||
task := qtask.Data
|
||||
if err := task.Execute(qtask.Context()); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err)
|
||||
} else {
|
||||
log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID)
|
||||
}
|
||||
qe.Done(qtask.ID)
|
||||
<-semaphore
|
||||
common.Log.Debugf("Task done: %s; status: %s", task.String(), task.Status)
|
||||
queue.DoneTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
func Run() {
|
||||
common.Log.Info("Start processing tasks...")
|
||||
func Run(ctx context.Context) {
|
||||
log.FromContext(ctx).Info("Start processing tasks...")
|
||||
semaphore := make(chan struct{}, config.Cfg.Workers)
|
||||
for i := 0; i < config.Cfg.Workers; i++ {
|
||||
go worker(queue.Queue, semaphore)
|
||||
if queueInstance == nil {
|
||||
queueInstance = queue.NewTaskQueue[Exectable]()
|
||||
}
|
||||
for range config.Cfg.Workers {
|
||||
go worker(ctx, queueInstance, semaphore)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func AddTask(ctx context.Context, task Exectable) error {
|
||||
return queueInstance.Add(queue.NewTask(ctx, task.TaskID(), task))
|
||||
}
|
||||
|
||||
func CancelTask(ctx context.Context, id string) error {
|
||||
err := queueInstance.CancelTask(id)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetLength(ctx context.Context) int {
|
||||
if queueInstance == nil {
|
||||
return 0
|
||||
}
|
||||
return queueInstance.ActiveLength()
|
||||
}
|
||||
|
||||
291
core/download.go
291
core/download.go
@@ -1,291 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/duke-git/lancet/v2/fileutil"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/krau/SaveAny-Bot/userclient"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func processPendingTask(task *types.Task) error {
|
||||
common.Log.Infof("Start processing task: %s", task.String())
|
||||
|
||||
if task.FileName() == "" {
|
||||
task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
|
||||
}
|
||||
|
||||
taskStorage, storagePath, err := getStorageAndPathForTask(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if taskStorage == nil {
|
||||
return fmt.Errorf("not found storage: %s", task.StorageName)
|
||||
}
|
||||
task.StoragePath = storagePath
|
||||
|
||||
ctx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
|
||||
}
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
task.Cancel = cancel
|
||||
|
||||
if task.IsTelegraph {
|
||||
return processTelegraph(ctx, cancelCtx, task, taskStorage)
|
||||
}
|
||||
|
||||
if task.File.FileSize == 0 {
|
||||
return processPhoto(task, taskStorage)
|
||||
}
|
||||
api := bot.Client.API()
|
||||
if task.UseUserClient && userclient.UC != nil {
|
||||
api = userclient.UC.API()
|
||||
}
|
||||
downloadBuilder := Downloader.Download(api, task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
|
||||
notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
|
||||
cancelMarkUp := getCancelTaskMarkup(task)
|
||||
|
||||
if config.Cfg.Stream {
|
||||
if !notsupportStream {
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
task.StartTime = time.Now()
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
|
||||
|
||||
eg, uploadCtx := errgroup.WithContext(cancelCtx)
|
||||
|
||||
eg.Go(func() error {
|
||||
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
_, err := downloadBuilder.Stream(uploadCtx, progressStream)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
|
||||
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
|
||||
cacheDestPath, err = filepath.Abs(cacheDestPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理路径失败: %w", err)
|
||||
}
|
||||
if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: cancelMarkUp,
|
||||
})
|
||||
}
|
||||
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件失败: %w", err)
|
||||
}
|
||||
defer dest.Close()
|
||||
task.StartTime = time.Now()
|
||||
_, err = downloadBuilder.Parallel(cancelCtx, dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
defer cleanCacheFile(cacheDestPath)
|
||||
|
||||
fixTaskFileExt(task, cacheDestPath)
|
||||
|
||||
common.Log.Infof("Downloaded file: %s", cacheDestPath)
|
||||
if task.ReplyMessageID != 0 {
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
}
|
||||
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
|
||||
}
|
||||
|
||||
func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error {
|
||||
if bot.TelegraphClient == nil {
|
||||
return fmt.Errorf("telegraph client is not initialized")
|
||||
}
|
||||
tgphUrl := task.TelegraphURL
|
||||
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
|
||||
if tgphUrl == "" || tgphPath == "" {
|
||||
return fmt.Errorf("invalid telegraph url")
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
text := fmt.Sprintf("正在下载 Telegraph \n文件夹: %s\n保存路径: %s",
|
||||
task.FileName(),
|
||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||
)
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在下载 Telegraph \n文件夹: "),
|
||||
styling.Code(task.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||
); err != nil {
|
||||
common.Log.Errorf("Failed to build entities: %s", err)
|
||||
}
|
||||
|
||||
if task.ReplyMessageID != 0 {
|
||||
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
}
|
||||
|
||||
resultCh := make(chan error)
|
||||
go func() {
|
||||
page, err := bot.TelegraphClient.GetPage(tgphPath, true)
|
||||
if err != nil {
|
||||
resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err)
|
||||
return
|
||||
}
|
||||
imgs := make([]string, 0)
|
||||
for _, element := range page.Content {
|
||||
var node telegraph.NodeElement
|
||||
data, err := json.Marshal(element)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to marshal element: %s", err)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(data, &node)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to unmarshal element: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(node.Children) != 0 {
|
||||
for _, child := range node.Children {
|
||||
imgs = append(imgs, getNodeImages(child)...)
|
||||
}
|
||||
}
|
||||
|
||||
if node.Tag == "img" {
|
||||
if src, ok := node.Attrs["src"]; ok {
|
||||
imgs = append(imgs, src)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if len(imgs) == 0 {
|
||||
resultCh <- fmt.Errorf("没有找到图片")
|
||||
return
|
||||
}
|
||||
hc := bot.TelegraphClient.HttpClient
|
||||
eg, ectx := errgroup.WithContext(cancelCtx)
|
||||
eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this
|
||||
for i, img := range imgs {
|
||||
if strings.HasPrefix(img, "/file/") {
|
||||
img = "https://telegra.ph" + img
|
||||
}
|
||||
eg.Go(func() error {
|
||||
var lastErr error
|
||||
for attempt := range config.Cfg.Retry {
|
||||
if attempt > 0 {
|
||||
retryDelay := time.Duration(attempt*attempt) * time.Second
|
||||
select {
|
||||
case <-ectx.Done():
|
||||
return ectx.Err()
|
||||
case <-time.After(retryDelay):
|
||||
}
|
||||
common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("发送请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("请求图片失败: %s", resp.Status)
|
||||
continue
|
||||
}
|
||||
targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img)))
|
||||
err = taskStorage.Save(ectx, resp.Body, targetPath)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("保存图片失败: %w", err)
|
||||
continue
|
||||
}
|
||||
common.Log.Infof("Saved image: %s", targetPath)
|
||||
return nil
|
||||
}
|
||||
return lastErr
|
||||
})
|
||||
}
|
||||
if err := eg.Wait(); err != nil {
|
||||
resultCh <- err
|
||||
return
|
||||
}
|
||||
resultCh <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
return err
|
||||
case <-cancelCtx.Done():
|
||||
return cancelCtx.Err()
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
)
|
||||
|
||||
func TestGetImgSrcs(t *testing.T) {
|
||||
complexStructure := telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image1.png",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "p",
|
||||
Children: []telegraph.Node{
|
||||
"A text node",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image2.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image3.png",
|
||||
},
|
||||
},
|
||||
"text node",
|
||||
telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "span",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image4.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"https://example.com/image1.png",
|
||||
"https://example.com/image2.png",
|
||||
"https://example.com/image3.png",
|
||||
"https://example.com/image4.png",
|
||||
}
|
||||
|
||||
got := getNodeImages(complexStructure)
|
||||
|
||||
if !reflect.DeepEqual(expected, got) {
|
||||
t.Errorf("expected %v,got %v", expected, got)
|
||||
}
|
||||
}
|
||||
110
core/rule.go
110
core/rule.go
@@ -1,110 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"regexp"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func getStorageAndPathForTask(task *types.Task) (storage.Storage, string, error) {
|
||||
user, err := dao.GetUserByChatID(task.UserID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to get user by chat ID: %w", err)
|
||||
}
|
||||
if task.StoragePath == "" {
|
||||
task.StoragePath = task.FileName()
|
||||
}
|
||||
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
storagePath := taskStorage.JoinStoragePath(*task)
|
||||
|
||||
var ruleTaskStorage storage.Storage
|
||||
var ruleStoragePath string
|
||||
if user.ApplyRule && user.Rules != nil {
|
||||
for _, rule := range user.Rules {
|
||||
matchStorage, matchStoragePath := applyRule(&rule, *task)
|
||||
if matchStorage != nil && matchStoragePath != "" {
|
||||
ruleTaskStorage = matchStorage
|
||||
ruleStoragePath = matchStoragePath
|
||||
common.Log.Debugf("Rule matched: %s, %s", ruleTaskStorage.Name(), ruleStoragePath)
|
||||
return ruleTaskStorage, ruleStoragePath, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if taskStorage.Exists(task.Ctx, storagePath) {
|
||||
ext := path.Ext(task.FileName())
|
||||
name := task.FileName()[:len(task.FileName())-len(ext)]
|
||||
task.File.FileName = fmt.Sprintf("%s_%d%s", name, task.FileDBID, ext)
|
||||
task.StoragePath = task.File.FileName
|
||||
storagePath = taskStorage.JoinStoragePath(*task)
|
||||
}
|
||||
|
||||
return taskStorage, storagePath, nil
|
||||
}
|
||||
|
||||
func applyRule(rule *dao.Rule, task types.Task) (storage.Storage, string) {
|
||||
var DirPath, StorageName string
|
||||
switch rule.Type {
|
||||
case string(types.RuleTypeFileNameRegex):
|
||||
ruleRegex, err := regexp.Compile(rule.Data)
|
||||
if err != nil {
|
||||
common.Log.Errorf("failed to compile regex: %s", err)
|
||||
return nil, ""
|
||||
}
|
||||
if !ruleRegex.MatchString(task.FileName()) {
|
||||
return nil, ""
|
||||
}
|
||||
DirPath = rule.DirPath
|
||||
StorageName = rule.StorageName
|
||||
case string(types.RuleTypeMessageRegex):
|
||||
ruleRegex, err := regexp.Compile(rule.Data)
|
||||
if err != nil {
|
||||
common.Log.Errorf("failed to compile regex: %s", err)
|
||||
return nil, ""
|
||||
}
|
||||
ctx, ok := task.Ctx.(*ext.Context)
|
||||
if !ok {
|
||||
common.Log.Fatalf("context is not *ext.Context: %T", task.Ctx)
|
||||
return nil, ""
|
||||
}
|
||||
msg, err := bot.GetTGMessage(ctx, task.FileChatID, task.FileMessageID)
|
||||
if err != nil {
|
||||
common.Log.Errorf("failed to get message: %s", err)
|
||||
return nil, ""
|
||||
}
|
||||
if msg == nil {
|
||||
return nil, ""
|
||||
}
|
||||
if !ruleRegex.MatchString(msg.GetMessage()) {
|
||||
return nil, ""
|
||||
}
|
||||
DirPath = rule.DirPath
|
||||
StorageName = rule.StorageName
|
||||
default:
|
||||
common.Log.Errorf("unknown rule type: %s", rule.Type)
|
||||
return nil, ""
|
||||
}
|
||||
taskStorageName := func() string {
|
||||
if StorageName == "" || StorageName == "CHOSEN" {
|
||||
return task.StorageName
|
||||
}
|
||||
return StorageName
|
||||
}()
|
||||
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, taskStorageName)
|
||||
if err != nil {
|
||||
common.Log.Errorf("failed to get storage: %s", err)
|
||||
return nil, ""
|
||||
}
|
||||
task.StoragePath = path.Join(DirPath, task.StoragePath)
|
||||
return taskStorage, taskStorage.JoinStoragePath(task)
|
||||
}
|
||||
82
core/tftask/execute.go
Normal file
82
core/tftask/execute.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/key"
|
||||
)
|
||||
|
||||
func (t *TGFileTask) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
|
||||
t.Progress.OnStart(ctx, t)
|
||||
if t.stream {
|
||||
return executeStream(ctx, t)
|
||||
}
|
||||
|
||||
logger.Info("Starting file download")
|
||||
localFile, err := fsutil.CreateFile(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := localFile.CloseAndRemove(); err != nil {
|
||||
logger.Errorf("Failed to close local file: %v", err)
|
||||
}
|
||||
}()
|
||||
wrAt := newWriterAt(ctx, localFile, t.Progress, t)
|
||||
|
||||
defer func() {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}()
|
||||
_, err = tdler.NewDownloader(t.client, t.File).Parallel(ctx, wrAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
logger.Infof("File downloaded successfully")
|
||||
if path.Ext(t.File.Name()) == "" {
|
||||
ext := fsutil.DetectFileExt(t.localPath)
|
||||
if ext != "" {
|
||||
t.Path = t.Path + ext
|
||||
}
|
||||
}
|
||||
var fileStat os.FileInfo
|
||||
fileStat, err = os.Stat(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size())
|
||||
for i := range config.Cfg.Retry + 1 {
|
||||
if err = vctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||
}
|
||||
var file *os.File
|
||||
file, err = os.Open(t.localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if err = t.Storage.Save(vctx, file, t.Path); err != nil {
|
||||
if i == config.Cfg.Retry {
|
||||
return fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
logger.Errorf("Failed to save file: %s, retrying...", err)
|
||||
select {
|
||||
case <-vctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
|
||||
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to save file after retries")
|
||||
|
||||
}
|
||||
186
core/tftask/progress.go
Normal file
186
core/tftask/progress.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, info TaskInfo)
|
||||
OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64)
|
||||
OnDone(ctx context.Context, info TaskInfo, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
MessageID int
|
||||
ChatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
log.FromContext(ctx).Debugf("Progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("开始下载\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
styling.Plain("\n文件大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB", float64(info.FileSize())/(1024*1024))),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) {
|
||||
if !shouldUpdateProgress(total, downloaded, int(p.lastUpdatePercent.Load())) {
|
||||
return
|
||||
}
|
||||
percent := int32((downloaded * 100) / total)
|
||||
if p.lastUpdatePercent.Load() == percent {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(percent)
|
||||
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.FileName(), downloaded, total)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在处理下载任务\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
styling.Plain("\n文件大小: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB", float64(total)/(1024*1024))),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))),
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Errorf("Progress error for file [%s]: %v", info.FileName(), err)
|
||||
} else {
|
||||
log.FromContext(ctx).Debugf("Progress done for file [%s]", info.FileName())
|
||||
}
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
var stylingErr error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("任务已取消\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
)
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("下载失败\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n错误: "),
|
||||
styling.Bold(err.Error()),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
stylingErr = styling.Perform(&entityBuilder,
|
||||
styling.Plain("下载完成\n文件名: "),
|
||||
styling.Code(info.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
)
|
||||
}
|
||||
|
||||
if stylingErr != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", stylingErr)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressOption func(*Progress)
|
||||
|
||||
func NewProgressTrack(
|
||||
messageID int,
|
||||
chatID int64,
|
||||
opts ...ProgressOption,
|
||||
) ProgressTracker {
|
||||
p := &Progress{
|
||||
MessageID: messageID,
|
||||
ChatID: chatID,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
40
core/tftask/stream.go
Normal file
40
core/tftask/stream.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func executeStream(ctx context.Context, task *TGFileTask) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
errg, uploadCtx := errgroup.WithContext(ctx)
|
||||
errg.Go(func() error {
|
||||
return task.Storage.Save(uploadCtx, pr, task.Path)
|
||||
})
|
||||
wr := newWriter(ctx, pw, task.Progress, task)
|
||||
errg.Go(func() error {
|
||||
logger.Info("Starting file download in stream mode")
|
||||
_, err := tdler.NewDownloader(task.client, task.File).Stream(uploadCtx, wr)
|
||||
if closeErr := pw.CloseWithError(err); closeErr != nil {
|
||||
logger.Errorf("Failed to close pipe writer: %v", closeErr)
|
||||
}
|
||||
return err
|
||||
})
|
||||
var err error
|
||||
defer func() {
|
||||
task.Progress.OnDone(ctx, task, err)
|
||||
}()
|
||||
if err = errg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("File downloaded successfully in stream mode")
|
||||
return nil
|
||||
}
|
||||
29
core/tftask/taskinfo.go
Normal file
29
core/tftask/taskinfo.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tftask
|
||||
|
||||
type TaskInfo interface {
|
||||
TaskID() string
|
||||
FileName() string
|
||||
FileSize() int64
|
||||
StoragePath() string
|
||||
StorageName() string
|
||||
}
|
||||
|
||||
func (t *TGFileTask) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileName() string {
|
||||
return t.File.Name()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileSize() int64 {
|
||||
return t.File.Size()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StoragePath() string {
|
||||
return t.Path
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StorageName() string {
|
||||
return t.Storage.Name()
|
||||
}
|
||||
64
core/tftask/tftask.go
Normal file
64
core/tftask/tftask.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
type TGFileTask struct {
|
||||
ID string
|
||||
Ctx context.Context
|
||||
File tfile.TGFile
|
||||
Storage storage.Storage
|
||||
Path string
|
||||
Progress ProgressTracker
|
||||
client tdler.Client
|
||||
stream bool // true if the file should be downloaded in stream mode
|
||||
localPath string
|
||||
}
|
||||
|
||||
func NewTGFileTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
file tfile.TGFile,
|
||||
client tdler.Client,
|
||||
stor storage.Storage,
|
||||
path string,
|
||||
progress ProgressTracker,
|
||||
) (*TGFileTask, error) {
|
||||
_, ok := stor.(storage.StorageCannotStream)
|
||||
if !config.Cfg.Stream || ok {
|
||||
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
|
||||
}
|
||||
tftask := &TGFileTask{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
File: file,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
Progress: progress,
|
||||
localPath: cachePath,
|
||||
}
|
||||
return tftask, nil
|
||||
}
|
||||
tfileTask := &TGFileTask{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
File: file,
|
||||
Storage: stor,
|
||||
Path: path,
|
||||
Progress: progress,
|
||||
stream: true,
|
||||
}
|
||||
return tfileTask, nil
|
||||
}
|
||||
32
core/tftask/util.go
Normal file
32
core/tftask/util.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package tftask
|
||||
|
||||
var progressUpdatesLevels = []struct {
|
||||
size int64 // 文件大小阈值
|
||||
stepPercent int // 每多少 % 更新一次
|
||||
}{
|
||||
{10 << 20, 100},
|
||||
{50 << 20, 20},
|
||||
{200 << 20, 10},
|
||||
{500 << 20, 5},
|
||||
}
|
||||
|
||||
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
|
||||
if total <= 0 || downloaded <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
percent := int((downloaded * 100) / total)
|
||||
if percent <= lastUpdatePercent {
|
||||
return false
|
||||
}
|
||||
|
||||
step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent
|
||||
for _, lvl := range progressUpdatesLevels {
|
||||
if total < lvl.size {
|
||||
step = lvl.stepPercent
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return percent >= lastUpdatePercent+step
|
||||
}
|
||||
75
core/tftask/writer.go
Normal file
75
core/tftask/writer.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package tftask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type ProgressWriterAt struct {
|
||||
ctx context.Context
|
||||
wrAt io.WriterAt
|
||||
progress ProgressTracker
|
||||
downloaded *atomic.Int64
|
||||
total int64
|
||||
info TaskInfo
|
||||
}
|
||||
|
||||
func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
at, err := w.wrAt.WriteAt(p, off)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func newWriterAt(
|
||||
ctx context.Context,
|
||||
wrAt io.WriterAt,
|
||||
progress ProgressTracker,
|
||||
taskInfo TaskInfo,
|
||||
) *ProgressWriterAt {
|
||||
return &ProgressWriterAt{
|
||||
ctx: ctx,
|
||||
progress: progress,
|
||||
downloaded: &atomic.Int64{},
|
||||
total: taskInfo.FileSize(),
|
||||
wrAt: wrAt,
|
||||
info: taskInfo,
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
ctx context.Context
|
||||
wrAt io.Writer
|
||||
progress ProgressTracker
|
||||
downloaded *atomic.Int64
|
||||
total int64
|
||||
info TaskInfo
|
||||
}
|
||||
|
||||
func (w *ProgressWriter) Write(p []byte) (int, error) {
|
||||
at, err := w.wrAt.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func newWriter(
|
||||
ctx context.Context,
|
||||
wr io.Writer,
|
||||
progress ProgressTracker,
|
||||
taskInfo TaskInfo,
|
||||
) *ProgressWriter {
|
||||
return &ProgressWriter{
|
||||
ctx: ctx,
|
||||
progress: progress,
|
||||
downloaded: &atomic.Int64{},
|
||||
total: taskInfo.FileSize(),
|
||||
wrAt: wr,
|
||||
info: taskInfo,
|
||||
}
|
||||
}
|
||||
94
core/tphtask/execute.go
Normal file
94
core/tphtask/execute.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package tphtask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/duke-git/lancet/v2/retry"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"go.uber.org/multierr"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting Telegraph task %s", t.PhPath)
|
||||
t.progress.OnStart(ctx, t)
|
||||
eg, gctx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(config.Cfg.Workers)
|
||||
for i, pic := range t.Pics {
|
||||
pic := pic
|
||||
i := i
|
||||
eg.Go(func() error {
|
||||
err := t.processPic(gctx, pic, i)
|
||||
if err != nil {
|
||||
logger.Errorf("Error processing picture %s: %v", pic, err)
|
||||
return fmt.Errorf("failed to process picture %s: %w", pic, err)
|
||||
}
|
||||
t.downloaded.Add(1)
|
||||
t.progress.OnProgress(gctx, t)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
logger.Errorf("Error during Telegraph task execution: %v", err)
|
||||
} else {
|
||||
logger.Infof("Telegraph task %s completed successfully", t.PhPath)
|
||||
}
|
||||
t.progress.OnDone(ctx, t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Task) processPic(ctx context.Context, picUrl string, index int) error {
|
||||
retryOpts := []retry.Option{
|
||||
retry.Context(ctx),
|
||||
retry.RetryTimes(uint(config.Cfg.Retry)),
|
||||
}
|
||||
var lastErr error
|
||||
err := retry.Retry(func() error {
|
||||
var body io.ReadCloser
|
||||
body, lastErr = t.client.Download(ctx, picUrl)
|
||||
if lastErr != nil {
|
||||
lastErr = fmt.Errorf("failed to download picture %s: %w", picUrl, lastErr)
|
||||
return lastErr
|
||||
}
|
||||
defer body.Close()
|
||||
filename := fmt.Sprintf("%d%s", index+1, path.Ext(picUrl))
|
||||
if t.cannotStream {
|
||||
cacheFile, err := fsutil.CreateFile(filepath.Join(config.Cfg.Temp.BasePath,
|
||||
fmt.Sprintf("tph_%s_%s", t.TaskID(), filename),
|
||||
))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to create cache file for picture %s: %w", filename, err)
|
||||
return lastErr
|
||||
}
|
||||
defer func() {
|
||||
if err := cacheFile.CloseAndRemove(); err != nil {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Errorf("Failed to close and remove cache file for picture %s: %v", filename, err)
|
||||
}
|
||||
}()
|
||||
_, lastErr = io.Copy(cacheFile, body)
|
||||
if lastErr != nil {
|
||||
lastErr = fmt.Errorf("failed to copy picture %s to cache file: %w", filename, lastErr)
|
||||
return lastErr
|
||||
}
|
||||
lastErr = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename))
|
||||
} else {
|
||||
lastErr = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename))
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
lastErr = fmt.Errorf("failed to save picture %s: %w", filename, lastErr)
|
||||
return lastErr
|
||||
}
|
||||
return nil
|
||||
}, retryOpts...)
|
||||
return multierr.Combine(err, lastErr)
|
||||
}
|
||||
150
core/tphtask/progress.go
Normal file
150
core/tphtask/progress.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package tphtask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, info TaskInfo)
|
||||
OnProgress(ctx context.Context, info TaskInfo)
|
||||
OnDone(ctx context.Context, info TaskInfo, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
MessageID int
|
||||
ChatID int64
|
||||
}
|
||||
|
||||
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Debugf("Telegraph task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("开始下载Telegraph\n图片数量: "),
|
||||
styling.Code(fmt.Sprintf("%d", info.TotalPics())),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
|
||||
if !shouldUpdateProgress(info.Downloaded(), int64(info.TotalPics())) {
|
||||
return
|
||||
}
|
||||
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Downloaded(), info.TotalPics())
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在下载\n当前进度: "),
|
||||
styling.Code(fmt.Sprintf("%d/%d", info.Downloaded(), info.TotalPics())),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
|
||||
logger := log.FromContext(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Telegraph task %s was canceled", info.TaskID())
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
Message: fmt.Sprintf("处理已取消: %s", info.TaskID()),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("Telegraph task %s failed: %s", info.TaskID(), err)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
Message: fmt.Sprintf("处理失败: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Telegraph task %s completed successfully", info.TaskID())
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("处理完成\n图片数量: "),
|
||||
styling.Code(fmt.Sprintf("%d", info.TotalPics())),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
|
||||
); err != nil {
|
||||
logger.Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
func NewProgress(messageID int, chatID int64) *Progress {
|
||||
return &Progress{
|
||||
MessageID: messageID,
|
||||
ChatID: chatID,
|
||||
}
|
||||
}
|
||||
51
core/tphtask/task.go
Normal file
51
core/tphtask/task.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package tphtask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
Ctx context.Context
|
||||
PhPath string
|
||||
Pics []string
|
||||
Stor storage.Storage
|
||||
StorPath string
|
||||
client *telegraph.Client
|
||||
progress ProgressTracker
|
||||
|
||||
cannotStream bool
|
||||
totalpics int
|
||||
downloaded atomic.Int64
|
||||
}
|
||||
|
||||
func NewTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
phPath string,
|
||||
pics []string,
|
||||
stor storage.Storage,
|
||||
storPath string,
|
||||
client *telegraph.Client,
|
||||
progress ProgressTracker,
|
||||
) *Task {
|
||||
_, cannotStream := stor.(storage.StorageCannotStream)
|
||||
tphtask := &Task{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
PhPath: phPath,
|
||||
Pics: pics,
|
||||
Stor: stor,
|
||||
StorPath: storPath,
|
||||
client: client,
|
||||
progress: progress,
|
||||
cannotStream: cannotStream,
|
||||
totalpics: len(pics),
|
||||
downloaded: atomic.Int64{},
|
||||
}
|
||||
return tphtask
|
||||
}
|
||||
34
core/tphtask/taskinfo.go
Normal file
34
core/tphtask/taskinfo.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package tphtask
|
||||
|
||||
type TaskInfo interface {
|
||||
TaskID() string
|
||||
Phpath() string
|
||||
TotalPics() int
|
||||
Downloaded() int64
|
||||
StorageName() string
|
||||
StoragePath() string
|
||||
}
|
||||
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *Task) Phpath() string {
|
||||
return t.PhPath
|
||||
}
|
||||
|
||||
func (t *Task) TotalPics() int {
|
||||
return t.totalpics
|
||||
}
|
||||
|
||||
func (t *Task) Downloaded() int64 {
|
||||
return t.downloaded.Load()
|
||||
}
|
||||
|
||||
func (t *Task) StorageName() string {
|
||||
return t.Stor.Name()
|
||||
}
|
||||
|
||||
func (t *Task) StoragePath() string {
|
||||
return t.StorPath
|
||||
}
|
||||
13
core/tphtask/utils.go
Normal file
13
core/tphtask/utils.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package tphtask
|
||||
|
||||
func shouldUpdateProgress(downloaded int64, total int64) bool {
|
||||
if total <= 0 || downloaded <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
step := int64(10)
|
||||
if downloaded < step {
|
||||
return downloaded == total
|
||||
}
|
||||
return downloaded%step == 0 || downloaded == total
|
||||
}
|
||||
303
core/utils.go
303
core/utils.go
@@ -1,303 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/krau/SaveAny-Bot/userclient"
|
||||
)
|
||||
|
||||
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
fileStat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
|
||||
for i := 0; i <= config.Cfg.Retry; i++ {
|
||||
if err := vctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||
}
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
|
||||
if i == config.Cfg.Retry {
|
||||
return fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
common.Log.Errorf("Failed to save file: %s, retrying...", err)
|
||||
select {
|
||||
case <-vctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
|
||||
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processPhoto(task *types.Task, taskStorage storage.Storage) error {
|
||||
api := bot.Client.API()
|
||||
if task.UseUserClient && userclient.UC != nil {
|
||||
api = userclient.UC.API()
|
||||
}
|
||||
res, err := api.UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
|
||||
Location: task.File.Location,
|
||||
Offset: 0,
|
||||
Limit: 1024 * 1024,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file: %w", err)
|
||||
}
|
||||
|
||||
result, ok := res.(*tg.UploadFile)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T", res)
|
||||
}
|
||||
|
||||
common.Log.Infof("Downloaded photo: %s", task.FileName())
|
||||
|
||||
return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath)
|
||||
}
|
||||
|
||||
func cleanCacheFile(destPath string) {
|
||||
if config.Cfg.Temp.CacheTTL > 0 {
|
||||
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
|
||||
} else {
|
||||
if err := os.Remove(destPath); err != nil {
|
||||
common.Log.Errorf("Failed to purge file: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取进度需要更新的次数
|
||||
func getProgressUpdateCount(fileSize int64) int {
|
||||
updateCount := 5
|
||||
if fileSize > 1024*1024*1000 {
|
||||
updateCount = 50
|
||||
} else if fileSize > 1024*1024*500 {
|
||||
updateCount = 20
|
||||
} else if fileSize > 1024*1024*200 {
|
||||
updateCount = 10
|
||||
}
|
||||
return updateCount
|
||||
}
|
||||
|
||||
func getSpeed(bytesRead int64, startTime time.Time) string {
|
||||
if startTime.IsZero() {
|
||||
return "0MB/s"
|
||||
}
|
||||
elapsed := time.Since(startTime)
|
||||
speed := float64(bytesRead) / 1024 / 1024 / elapsed.Seconds()
|
||||
return fmt.Sprintf("%.2fMB/s", speed)
|
||||
}
|
||||
|
||||
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
|
||||
entityBuilder := entity.Builder{}
|
||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
|
||||
task.FileName(),
|
||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||
getSpeed(bytesRead, startTime),
|
||||
progress,
|
||||
)
|
||||
var entities []tg.MessageEntityClass
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在处理下载任务\n文件名: "),
|
||||
styling.Code(task.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
|
||||
); err != nil {
|
||||
common.Log.Errorf("Failed to build entities: %s", err)
|
||||
return text, entities
|
||||
}
|
||||
return entityBuilder.Complete()
|
||||
}
|
||||
|
||||
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
|
||||
return func(bytesRead, contentLength int64) {
|
||||
progress := float64(bytesRead) / float64(contentLength) * 100
|
||||
common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
||||
progressInt := int(progress)
|
||||
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
|
||||
return
|
||||
}
|
||||
if task.ReplyMessageID == 0 {
|
||||
return
|
||||
}
|
||||
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
ReplyMarkup: getCancelTaskMarkup(task),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
|
||||
return &tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
|
||||
}
|
||||
}
|
||||
|
||||
func fixTaskFileExt(task *types.Task, localFilePath string) {
|
||||
if path.Ext(task.FileName()) == "" {
|
||||
mimeType, err := mimetype.DetectFile(localFilePath)
|
||||
if err != nil {
|
||||
common.Log.Errorf("Failed to detect mime type: %s", err)
|
||||
} else {
|
||||
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
|
||||
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTaskThreads(fileSize int64) int {
|
||||
threads := 1
|
||||
if fileSize > 1024*1024*100 {
|
||||
threads = config.Cfg.Threads
|
||||
} else if fileSize > 1024*1024*50 {
|
||||
threads = config.Cfg.Threads / 2
|
||||
}
|
||||
return threads
|
||||
}
|
||||
|
||||
type TaskLocalFile struct {
|
||||
file *os.File
|
||||
size int64
|
||||
done int64
|
||||
progressCallback func(bytesRead, contentLength int64)
|
||||
callbackTimes int64
|
||||
nextCallbackAt int64
|
||||
callbackInterval int64
|
||||
}
|
||||
|
||||
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
|
||||
return t.file.Read(p)
|
||||
}
|
||||
|
||||
func (t *TaskLocalFile) Close() error {
|
||||
return t.file.Close()
|
||||
}
|
||||
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
|
||||
n, err := t.file.WriteAt(p, off)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
t.done += int64(n)
|
||||
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
|
||||
t.progressCallback(t.done, t.size)
|
||||
t.nextCallbackAt += t.callbackInterval
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
var callbackInterval int64
|
||||
callbackInterval = fileSize / 100
|
||||
if callbackInterval == 0 {
|
||||
callbackInterval = 1
|
||||
}
|
||||
return &TaskLocalFile{
|
||||
file: file,
|
||||
size: fileSize,
|
||||
progressCallback: progressCallback,
|
||||
callbackTimes: 100,
|
||||
nextCallbackAt: callbackInterval,
|
||||
callbackInterval: callbackInterval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ProgressStream struct {
|
||||
writer io.Writer
|
||||
size int64
|
||||
done int64
|
||||
callback func(bytesRead, contentLength int64)
|
||||
nextAt int64
|
||||
interval int64
|
||||
}
|
||||
|
||||
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
|
||||
n, err = ps.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
ps.done += int64(n)
|
||||
if ps.callback != nil && ps.done >= ps.nextAt {
|
||||
ps.callback(ps.done, ps.size)
|
||||
ps.nextAt += ps.interval
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
|
||||
var interval int64
|
||||
interval = size / 100
|
||||
if interval == 0 {
|
||||
interval = 1
|
||||
}
|
||||
return &ProgressStream{
|
||||
writer: writer,
|
||||
size: size,
|
||||
callback: callback,
|
||||
nextAt: interval,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
func getNodeImages(node telegraph.Node) []string {
|
||||
var srcs []string
|
||||
|
||||
var nodeElement telegraph.NodeElement
|
||||
data, err := json.Marshal(node)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
err = json.Unmarshal(data, &nodeElement)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
|
||||
if nodeElement.Tag == "img" {
|
||||
if src, exists := nodeElement.Attrs["src"]; exists {
|
||||
srcs = append(srcs, src)
|
||||
}
|
||||
}
|
||||
for _, child := range nodeElement.Children {
|
||||
srcs = append(srcs, getNodeImages(child)...)
|
||||
}
|
||||
return srcs
|
||||
}
|
||||
Reference in New Issue
Block a user