refactor: refactor task logic for better scalability (#76)

* refactor: a big refactor. wip

* refactor: port handle file

* refactor: place all handlers

* fix: task info nil pointer

* feat: enhance task progress tracking and context management

* feat: cancel task

* feat: stream mode

* feat: silent mode

* feat: dir cmd

* refactor: remove unused old file

* feat: rule cmd

* feat: handle silent mode

* feat: batch task

* fix: batch task progress and temp file cleanup

* refactor: update file creation and cleanup methods for better resource management

* feat: add save command with silent mode handling

* feat: message link

* feat: update message prompts to include file count in storage selection

* feat: silent save links

* refactor: reduce dup code

* feat: rule type

* feat: choose dir

* feat: refactor file handling and storage rules, improve error handling and logging

* feat: rule mode

* feat: telegraph pics

* fix: tphpics nil pointer and inaccurate dirpath

* feat: silent save telegraph

* feat: add suffix to avoid file overwrite

* feat: new storage telegram

* chore: tidy go mod
This commit is contained in:
Krau
2025-06-15 23:57:49 +08:00
committed by GitHub
parent 280745cae3
commit 900823cdb9
150 changed files with 5730 additions and 3923 deletions

122
core/batchtftask/execute.go Normal file
View File

@@ -0,0 +1,122 @@
package batchtftask
import (
"context"
"fmt"
"io"
"os"
"path"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/retry"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/key"
"golang.org/x/sync/errgroup"
)
// Execute runs every element of the batch through processElement, with at
// most config.Cfg.Workers elements in flight at once.
//
// Progress.OnStart is invoked before any work starts and Progress.OnDone is
// invoked with the group's resulting error (nil on success) after all workers
// have returned. That same error is the return value.
//
// NOTE(review): t.processing is a plain map that is written here from several
// worker goroutines and iterated by Processing() without any lock visible in
// this file — confirm whether it needs a mutex.
func (t *Task) Execute(ctx context.Context) error {
	prefix := fmt.Sprintf("batch_file[%s]", t.ID)
	logger := log.FromContext(ctx).WithPrefix(prefix)
	logger.Info("Starting batch file task")
	t.Progress.OnStart(ctx, t)

	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(config.Cfg.Workers)
	for i := range t.Elems {
		// Each worker gets its own copy of the element.
		item := t.Elems[i]
		group.Go(func() error {
			if t.processing[item.ID] != nil {
				return fmt.Errorf("element with ID %s is already being processed", item.ID)
			}
			// Publish the element as "in progress" for the lifetime of the worker.
			t.processing[item.ID] = &item
			defer delete(t.processing, item.ID)
			return t.processElement(groupCtx, item)
		})
	}

	runErr := group.Wait()
	if runErr == nil {
		logger.Info("Batch file task completed successfully")
	} else {
		logger.Errorf("Error during batch file processing: %v", runErr)
	}
	t.Progress.OnDone(ctx, t, runErr)
	return runErr
}
// processElement downloads a single element and hands it to its storage.
//
// In stream mode the download is piped straight into elem.Storage.Save. In the
// default mode the file is first downloaded to elem.localPath, an extension is
// appended to the storage path when the file name has none and one can be
// detected, and the cached file is then saved with retries. The local file is
// closed and removed when the function returns.
//
// Every chunk written adds to t.downloaded and triggers Progress.OnProgress.
func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
	logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.File.Name()))

	// report is the progress hook shared by both download modes.
	report := func(n int) {
		t.downloaded.Add(int64(n))
		t.Progress.OnProgress(ctx, t)
	}

	if elem.stream {
		reader, writer := io.Pipe()
		defer reader.Close()

		group, groupCtx := errgroup.WithContext(ctx)
		// Consumer: the storage reads from the pipe.
		group.Go(func() error {
			return elem.Storage.Save(groupCtx, reader, elem.Path)
		})
		progressWriter := ioutil.NewProgressWriter(writer, report)
		// Producer: the downloader writes into the pipe and closes it with the
		// download result so the consumer sees EOF or the failure.
		group.Go(func() error {
			logger.Info("Starting file download in stream mode")
			_, dlErr := tdler.NewDownloader(t.client, elem.File).Stream(groupCtx, progressWriter)
			if closeErr := writer.CloseWithError(dlErr); closeErr != nil {
				logger.Errorf("Failed to close pipe writer: %v", closeErr)
			}
			return dlErr
		})
		if waitErr := group.Wait(); waitErr != nil {
			return fmt.Errorf("failed to download file in stream mode: %w", waitErr)
		}
		logger.Info("File downloaded successfully in stream mode")
		return nil
	}

	logger.Info("Starting file download")
	cacheFile, err := fsutil.CreateFile(elem.localPath)
	if err != nil {
		return fmt.Errorf("failed to create local file: %w", err)
	}
	defer func() {
		if rmErr := cacheFile.CloseAndRemove(); rmErr != nil {
			logger.Errorf("Failed to close local file: %v", rmErr)
		}
	}()

	progressWriterAt := ioutil.NewProgressWriterAt(cacheFile, report)
	if _, err = tdler.NewDownloader(t.client, elem.File).Parallel(ctx, progressWriterAt); err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	logger.Info("File downloaded successfully")

	// Files without an extension get one appended to the storage path when it
	// can be detected from the downloaded content.
	if path.Ext(elem.FileName()) == "" {
		if detected := fsutil.DetectFileExt(elem.localPath); detected != "" {
			elem.Path = elem.Path + detected
		}
	}

	info, err := os.Stat(elem.localPath)
	if err != nil {
		return fmt.Errorf("failed to get file stat: %w", err)
	}
	// The content length travels to the storage through the context.
	saveCtx := context.WithValue(ctx, key.ContextKeyContentLength, info.Size())

	saveOnce := func() error {
		cached, openErr := os.Open(elem.localPath)
		if openErr != nil {
			return fmt.Errorf("failed to open cache file: %w", openErr)
		}
		defer cached.Close()
		if saveErr := elem.Storage.Save(saveCtx, cached, elem.Path); saveErr != nil {
			logger.Errorf("Failed to save file: %s, retrying...", saveErr)
			return saveErr
		}
		return nil
	}
	return retry.Retry(saveOnce, retry.Context(saveCtx), retry.RetryTimes(uint(config.Cfg.Retry)))
}

View File

@@ -0,0 +1,176 @@
package batchtftask
import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/slice"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
// ProgressTracker receives lifecycle callbacks for a batch task.
type ProgressTracker interface {
// OnStart is called once before any element is processed.
OnStart(ctx context.Context, info TaskInfo)
// OnProgress is called after downloaded bytes have been counted.
OnProgress(ctx context.Context, info TaskInfo)
// OnDone is called once when the task finishes; err is nil on success.
OnDone(ctx context.Context, info TaskInfo, err error)
}
// Progress is a ProgressTracker that reports a batch task's state by editing
// a single chat message.
type Progress struct {
// MessageID is the ID of the message that is edited on every update.
MessageID int
// ChatID is the chat that contains the message.
ChatID int64
// start is set in OnStart and used to compute the average speed.
start time.Time
// lastUpdatePercent is the whole-number percentage of the last update sent.
lastUpdatePercent atomic.Int32
}
// OnStart records the start time, resets the last reported percentage and
// edits the tracked message to announce the batch (total size and file count)
// together with a cancel button. Nothing is sent when the context carries no
// ext context or when the entities cannot be built.
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
	p.start = time.Now()
	p.lastUpdatePercent.Store(0)

	logger := log.FromContext(ctx)
	logger.Debugf("Batch task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)

	var builder entity.Builder
	summary := fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())
	buildErr := styling.Perform(&builder,
		styling.Plain("开始执行批量下载任务\n总大小: "),
		styling.Code(summary),
	)
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}
	text, entities := builder.Complete()

	cancelRow := tg.KeyboardButtonRow{
		Buttons: []tg.KeyboardButtonClass{tgutil.BuildCancelButton(info.TaskID())},
	}
	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{Rows: []tg.KeyboardButtonRow{cancelRow}})

	if extCtx := tgutil.ExtFromContext(ctx); extCtx != nil {
		extCtx.EditMessage(p.ChatID, req)
	}
}
// OnProgress edits the tracked message with the current state of the batch:
// total size, the elements being processed, average speed and percentage.
//
// Updates are throttled through shouldUpdateProgress. Because the batch
// workers call this method concurrently, the last reported percentage is
// advanced with a compare-and-swap: only the goroutine that wins the swap
// sends the edit, so one threshold crossing produces one message edit. The
// byte counters are read once so the throttling decision and the rendered
// numbers agree.
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
	total, downloaded := info.TotalSize(), info.Downloaded()
	last := p.lastUpdatePercent.Load()
	if !shouldUpdateProgress(total, downloaded, int(last)) {
		return
	}
	// shouldUpdateProgress has already rejected total <= 0, so the division is safe.
	percent := int32((downloaded * 100) / total)
	if !p.lastUpdatePercent.CompareAndSwap(last, percent) {
		// Another goroutine reported progress in the meantime.
		return
	}

	logger := log.FromContext(ctx)
	logger.Debugf("Progress update: %s, %d/%d", info.TaskID(), downloaded, total)

	// One line per element currently being processed, or a placeholder.
	var lines []string
	for _, elem := range info.Processing() {
		lines = append(lines, fmt.Sprintf("  - %s (%.2f MB)", elem.FileName(), float64(elem.FileSize())/(1024*1024)))
	}
	if len(lines) == 0 {
		lines = append(lines, "  - 无")
	}

	entityBuilder := entity.Builder{}
	if err := styling.Perform(&entityBuilder,
		styling.Plain("正在处理批量下载任务\n总大小: "),
		styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(total)/(1024*1024), info.Count())),
		styling.Plain("\n正在处理:\n"),
		styling.Plain(slice.Join(lines, "\n")),
		styling.Plain("\n平均速度: "),
		styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))),
		styling.Plain("\n当前进度: "),
		styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)),
	); err != nil {
		logger.Errorf("Failed to build entities: %s", err)
		return
	}
	text, entities := entityBuilder.Complete()

	req := &tg.MessagesEditMessageRequest{
		ID: p.MessageID,
	}
	req.SetMessage(text)
	req.SetEntities(entities)
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{
		Rows: []tg.KeyboardButtonRow{
			{
				Buttons: []tg.KeyboardButtonClass{
					tgutil.BuildCancelButton(info.TaskID()),
				},
			},
		}},
	)
	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// OnDone logs the outcome of the batch and replaces the tracked message with
// a final status: a cancellation notice when err wraps context.Canceled, the
// error text for any other failure, or a summary with the file count and
// total size on success. The cancel button is not re-attached.
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
	logger := log.FromContext(ctx)
	if err == nil {
		logger.Debugf("Batch task %s completed successfully", info.TaskID())
	} else {
		logger.Errorf("Batch task %s failed: %s", info.TaskID(), err)
	}

	var builder entity.Builder
	var buildErr error
	switch {
	case err == nil:
		buildErr = styling.Perform(&builder,
			styling.Plain("处理完成\n文件数: "),
			styling.Code(strconv.Itoa(info.Count())),
			styling.Plain("\n总大小: "),
			styling.Code(fmt.Sprintf("%.2f MB", float64(info.TotalSize())/(1024*1024))),
		)
	case errors.Is(err, context.Canceled):
		buildErr = styling.Perform(&builder,
			styling.Plain("任务已取消"),
		)
	default:
		buildErr = styling.Perform(&builder,
			styling.Plain("处理失败, 错误:\n "),
			styling.Code(err.Error()),
		)
	}
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}

	text, entities := builder.Complete()
	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)
	if extCtx := tgutil.ExtFromContext(ctx); extCtx != nil {
		extCtx.EditMessage(p.ChatID, req)
	}
}
// NewProgressTracker returns a Progress bound to the given message and chat,
// exposed through the ProgressTracker interface.
func NewProgressTracker(messageID int, chatID int64) ProgressTracker {
	tracker := new(Progress)
	tracker.MessageID = messageID
	tracker.ChatID = chatID
	return tracker
}

94
core/batchtftask/task.go Normal file
View File

@@ -0,0 +1,94 @@
package batchtftask
import (
"context"
"fmt"
"path/filepath"
"sync/atomic"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
// TaskElement is one file of a batch together with its destination.
type TaskElement struct {
// ID is a unique identifier generated in NewTaskElement.
ID string
// Storage is the backend the file is saved to.
Storage storage.Storage
// Path is the destination path passed to Storage.Save.
Path string
// File is the source file to download.
File tfile.TGFile
// localPath is the absolute cache file path; only set when not streaming.
localPath string
// stream is true when the download is piped directly into Storage.
stream bool
}
// Task is a batch of TaskElements that are downloaded and saved together.
type Task struct {
// ID identifies the task; it is returned by TaskID.
ID string
// Ctx is the context supplied to NewBatchTGFileTask.
Ctx context.Context
// Elems are the files that make up the batch.
Elems []TaskElement
// Progress receives start, progress and done callbacks.
Progress ProgressTracker
// IgnoreErrors: if true, errors during processing will be ignored.
// NOTE(review): not read by the Execute method visible alongside this type — confirm where it is honoured.
IgnoreErrors bool
// downloaded counts the bytes downloaded so far across all elements.
downloaded atomic.Int64
// client is the download client handed to tdler.NewDownloader.
client tdler.Client
// totalSize is the sum of all element file sizes, computed at construction.
totalSize int64
// processing holds the elements currently being worked on, keyed by element ID.
processing map[string]TaskElementInfo
// failed holds errors for each element.
failed map[string]error
}
// NewTaskElement builds a TaskElement with a freshly generated ID.
//
// The element streams directly into the storage when streaming is enabled in
// the configuration and the storage does not implement
// storage.StorageCannotStream. Otherwise it is given an absolute cache path
// under config.Cfg.Temp.BasePath named "<id>_<file name>"; an error is
// returned if that path cannot be made absolute.
func NewTaskElement(
	stor storage.Storage,
	path string,
	file tfile.TGFile,
) (*TaskElement, error) {
	elem := &TaskElement{
		ID:      xid.New().String(),
		Storage: stor,
		Path:    path,
		File:    file,
	}

	_, cannotStream := stor.(storage.StorageCannotStream)
	if config.Cfg.Stream && !cannotStream {
		elem.stream = true
		return elem, nil
	}

	cacheName := fmt.Sprintf("%s_%s", elem.ID, file.Name())
	absPath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, cacheName))
	if err != nil {
		return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
	}
	elem.localPath = absPath
	return elem, nil
}
// NewBatchTGFileTask assembles a batch Task from the given elements.
//
// The total size is computed once here as the sum of every element's file
// size, and the processing and failed maps are initialised empty.
func NewBatchTGFileTask(
	id string,
	ctx context.Context,
	files []TaskElement,
	client tdler.Client,
	progress ProgressTracker,
	ignoreErrors bool,
) *Task {
	var sum int64
	for i := range files {
		sum += files[i].File.Size()
	}

	batch := &Task{
		ID:           id,
		Ctx:          ctx,
		Elems:        files,
		Progress:     progress,
		IgnoreErrors: ignoreErrors,
		client:       client,
		totalSize:    sum,
		processing:   make(map[string]TaskElementInfo),
		failed:       make(map[string]error),
	}
	return batch
}

View File

@@ -0,0 +1,56 @@
package batchtftask
// TaskElementInfo is the read-only view of a batch element that progress
// trackers receive.
type TaskElementInfo interface {
// FileName returns the name of the element's file.
FileName() string
// FileSize returns the size of the element's file.
FileSize() int64
// StoragePath returns the destination path within the storage.
StoragePath() string
// StorageName returns the name of the destination storage.
StorageName() string
}
// FileName returns the name reported by the element's File.
func (e *TaskElement) FileName() string {
return e.File.Name()
}
// FileSize returns the size reported by the element's File.
func (e *TaskElement) FileSize() int64 {
return e.File.Size()
}
// StoragePath returns the element's destination path (the Path field).
func (e *TaskElement) StoragePath() string {
return e.Path
}
// StorageName returns the name reported by the element's Storage.
func (e *TaskElement) StorageName() string {
return e.Storage.Name()
}
// TaskInfo is the read-only view of a batch task that progress trackers
// receive.
type TaskInfo interface {
// TaskID returns the task's identifier.
TaskID() string
// TotalSize returns the combined size of all files in the batch.
TotalSize() int64
// Downloaded returns the number of bytes downloaded so far.
Downloaded() int64
// Count returns the number of elements in the batch.
Count() int
// Processing returns the elements currently being processed.
Processing() []TaskElementInfo
}
// TaskID returns the task's ID field.
func (t *Task) TaskID() string {
return t.ID
}
// TotalSize returns the precomputed sum of all element file sizes.
func (t *Task) TotalSize() int64 {
return t.totalSize
}
// Downloaded atomically loads the number of bytes downloaded so far.
func (t *Task) Downloaded() int64 {
return t.downloaded.Load()
}
// Count returns the number of elements in the batch.
func (t *Task) Count() int {
return len(t.Elems)
}
// Processing returns a fresh slice with the elements currently registered in
// the task's processing map. The order follows map iteration and is therefore
// unspecified.
//
// NOTE(review): the map is read here without a lock — confirm callers cannot
// run concurrently with the code that adds and removes entries.
func (t *Task) Processing() []TaskElementInfo {
	active := make([]TaskElementInfo, 0, len(t.Elems))
	for id := range t.processing {
		active = append(active, t.processing[id])
	}
	return active
}

32
core/batchtftask/utils.go Normal file
View File

@@ -0,0 +1,32 @@
package batchtftask
// progressUpdatesLevels maps a task's total size to how often its progress is
// reported. Entries are ordered by ascending size; the first entry whose size
// is greater than the total applies, and the last entry's step is used for
// totals that exceed every threshold.
var progressUpdatesLevels = []struct {
	size        int64 // file size threshold in bytes (exclusive upper bound)
	stepPercent int   // report once every this many percent
}{
	{size: 10 << 20, stepPercent: 100},
	{size: 50 << 20, stepPercent: 20},
	{size: 200 << 20, stepPercent: 10},
	{size: 500 << 20, stepPercent: 5},
}

// shouldUpdateProgress reports whether enough progress has been made since
// lastUpdatePercent to justify another update.
//
// It returns false when total or downloaded is not positive, or when the
// current whole-number percentage has not moved past lastUpdatePercent.
// Otherwise the percentage must have grown by at least the step configured in
// progressUpdatesLevels for the given total.
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
	if total <= 0 || downloaded <= 0 {
		return false
	}
	current := int(downloaded * 100 / total)
	if current <= lastUpdatePercent {
		return false
	}

	levels := progressUpdatesLevels
	required := levels[len(levels)-1].stepPercent
	for i := 0; i < len(levels); i++ {
		if total < levels[i].size {
			required = levels[i].stepPercent
			break
		}
	}
	return current-lastUpdatePercent >= required
}

View File

@@ -2,92 +2,62 @@ package core
import (
"context"
"errors"
"fmt"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/telegram/downloader"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/pkg/queue"
)
var Downloader *downloader.Downloader
var queueInstance *queue.TaskQueue[Exectable]
func init() {
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
type Exectable interface {
TaskID() string
Execute(ctx context.Context) error
}
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) {
for {
semaphore <- struct{}{}
task := queue.GetTask()
common.Log.Debugf("Got task: %s", task.String())
switch task.Status {
case types.Pending:
common.Log.Infof("Processing task: %s", task.String())
if err := processPendingTask(task); err != nil {
task.Error = err
if errors.Is(err, context.Canceled) {
task.Status = types.Canceled
} else {
common.Log.Errorf("Failed to do task: %s", err)
task.Status = types.Failed
}
} else {
task.Status = types.Succeeded
}
queue.AddTask(task)
case types.Succeeded:
common.Log.Infof("Task succeeded: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
})
}
case types.Failed:
common.Log.Errorf("Task failed: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})
}
case types.Canceled:
common.Log.Infof("Task canceled: %s", task.String())
extCtx, ok := task.Ctx.(*ext.Context)
if !ok {
common.Log.Errorf("Context is not *ext.Context: %T", task.Ctx)
} else if task.ReplyMessageID != 0 {
extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "任务已取消",
ID: task.ReplyMessageID,
})
}
default:
common.Log.Errorf("Unknown task status: %s", task.Status)
qtask, err := qe.Get()
if err != nil {
break // queue closed and empty
}
log.FromContext(ctx).Infof("Processing task: %s", qtask.ID)
task := qtask.Data
if err := task.Execute(qtask.Context()); err != nil {
log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err)
} else {
log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID)
}
qe.Done(qtask.ID)
<-semaphore
common.Log.Debugf("Task done: %s; status: %s", task.String(), task.Status)
queue.DoneTask(task)
}
}
func Run() {
common.Log.Info("Start processing tasks...")
func Run(ctx context.Context) {
log.FromContext(ctx).Info("Start processing tasks...")
semaphore := make(chan struct{}, config.Cfg.Workers)
for i := 0; i < config.Cfg.Workers; i++ {
go worker(queue.Queue, semaphore)
if queueInstance == nil {
queueInstance = queue.NewTaskQueue[Exectable]()
}
for range config.Cfg.Workers {
go worker(ctx, queueInstance, semaphore)
}
}
// AddTask wraps task in a queue task keyed by task.TaskID() and adds it to the
// package-level queue, returning whatever error the queue reports.
// NOTE(review): unlike GetLength there is no nil check on queueInstance here —
// confirm Run is always called before tasks are added.
func AddTask(ctx context.Context, task Exectable) error {
return queueInstance.Add(queue.NewTask(ctx, task.TaskID(), task))
}
// CancelTask asks the package-level queue to cancel the task with the given
// id and returns the queue's result. ctx is currently unused.
func CancelTask(ctx context.Context, id string) error {
	return queueInstance.CancelTask(id)
}
// GetLength returns the queue's active length, or 0 when the queue has not
// been created yet. ctx is currently unused.
func GetLength(ctx context.Context) int {
	if queueInstance != nil {
		return queueInstance.ActiveLength()
	}
	return 0
}

View File

@@ -1,291 +0,0 @@
package core
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"
"path/filepath"
"strings"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/userclient"
"golang.org/x/sync/errgroup"
)
// processPendingTask downloads the file described by task and stores it.
//
// It resolves the destination storage and path, derives a cancellable context
// (its cancel function is stored in task.Cancel) and then dispatches:
// Telegraph tasks go to processTelegraph, files with size 0 to processPhoto.
// Everything else is downloaded either in stream mode (piped straight into
// the storage) or into a cache file under config.Cfg.Temp.BasePath that is
// then saved with retries.
//
// The cache file's cleanup is registered as soon as the file has been created,
// so it is also removed when the download fails, and it runs after the file
// has been closed.
func processPendingTask(task *types.Task) error {
	common.Log.Infof("Start processing task: %s", task.String())
	if task.FileName() == "" {
		task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash())
	}

	taskStorage, storagePath, err := getStorageAndPathForTask(task)
	if err != nil {
		return err
	}
	if taskStorage == nil {
		return fmt.Errorf("not found storage: %s", task.StorageName)
	}
	task.StoragePath = storagePath

	ctx, ok := task.Ctx.(*ext.Context)
	if !ok {
		return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
	}

	cancelCtx, cancel := context.WithCancel(ctx)
	task.Cancel = cancel

	if task.IsTelegraph {
		return processTelegraph(ctx, cancelCtx, task, taskStorage)
	}

	if task.File.FileSize == 0 {
		return processPhoto(task, taskStorage)
	}

	api := bot.Client.API()
	if task.UseUserClient && userclient.UC != nil {
		api = userclient.UC.API()
	}
	downloadBuilder := Downloader.Download(api, task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))

	notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
	cancelMarkUp := getCancelTaskMarkup(task)

	if config.Cfg.Stream {
		if !notsupportStream {
			text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
			if task.ReplyMessageID != 0 {
				ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
					Message:     text,
					Entities:    entities,
					ID:          task.ReplyMessageID,
					ReplyMarkup: cancelMarkUp,
				})
			}
			pr, pw := io.Pipe()
			defer pr.Close()

			task.StartTime = time.Now()
			progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
			progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)

			eg, uploadCtx := errgroup.WithContext(cancelCtx)
			// Consumer: the storage reads from the pipe.
			eg.Go(func() error {
				return taskStorage.Save(uploadCtx, pr, task.StoragePath)
			})
			// Producer: the downloader writes into the pipe and closes it with
			// the download result so the consumer sees EOF or the failure.
			eg.Go(func() error {
				_, err := downloadBuilder.Stream(uploadCtx, progressStream)
				if closeErr := pw.CloseWithError(err); closeErr != nil {
					common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
				}
				return err
			})
			if err := eg.Wait(); err != nil {
				return err
			}
			return nil
		}
		// Streaming is enabled but this storage cannot stream: fall back to
		// the cache-file path below.
		common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
		if task.ReplyMessageID != 0 {
			ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
				Message:     fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
				ID:          task.ReplyMessageID,
				ReplyMarkup: cancelMarkUp,
			})
		}
	}

	cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
	cacheDestPath, err = filepath.Abs(cacheDestPath)
	if err != nil {
		return fmt.Errorf("处理路径失败: %w", err)
	}
	if err := fileutil.CreateDir(filepath.Dir(cacheDestPath)); err != nil {
		return fmt.Errorf("创建目录失败: %w", err)
	}

	text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
	if task.ReplyMessageID != 0 {
		ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
			Message:     text,
			Entities:    entities,
			ID:          task.ReplyMessageID,
			ReplyMarkup: cancelMarkUp,
		})
	}

	progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
	dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
	if err != nil {
		return fmt.Errorf("创建文件失败: %w", err)
	}
	// Deferred calls run last-in-first-out: the file is closed first and the
	// cache file is removed afterwards, on every return path from here on
	// (including a failed or cancelled download).
	defer cleanCacheFile(cacheDestPath)
	defer dest.Close()

	task.StartTime = time.Now()
	_, err = downloadBuilder.Parallel(cancelCtx, dest)
	if err != nil {
		return fmt.Errorf("下载文件失败: %w", err)
	}

	fixTaskFileExt(task, cacheDestPath)

	common.Log.Infof("Downloaded file: %s", cacheDestPath)
	if task.ReplyMessageID != 0 {
		ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
			Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
			ID:      task.ReplyMessageID,
		})
	}

	return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
}
// processTelegraph downloads every image of the Telegraph page referenced by
// task.TelegraphURL and saves them under task.StoragePath as "<index><ext>".
//
// The page is fetched and processed in a background goroutine so that the
// function can return as soon as cancelCtx is cancelled. The result channel is
// buffered so that goroutine can always deliver its result and exit, even when
// nobody is waiting for it any more.
//
// Each image is attempted up to config.Cfg.Retry times (at least once) with a
// quadratic back-off between attempts; the response body is closed after
// every attempt.
func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error {
	if bot.TelegraphClient == nil {
		return fmt.Errorf("telegraph client is not initialized")
	}
	tgphUrl := task.TelegraphURL
	segments := strings.Split(tgphUrl, "/")
	tgphPath := segments[len(segments)-1]
	if tgphUrl == "" || tgphPath == "" {
		return fmt.Errorf("invalid telegraph url")
	}

	// Plain-text fallback, used when the styled entities cannot be built.
	text := fmt.Sprintf("正在下载 Telegraph \n文件夹: %s\n保存路径: %s",
		task.FileName(),
		fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
	)
	var entities []tg.MessageEntityClass
	entityBuilder := entity.Builder{}
	if err := styling.Perform(&entityBuilder,
		styling.Plain("正在下载 Telegraph \n文件夹: "),
		styling.Code(task.FileName()),
		styling.Plain("\n保存路径: "),
		styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
	); err != nil {
		common.Log.Errorf("Failed to build entities: %s", err)
	} else {
		// Take the text and entities from the builder so the formatting is
		// actually sent with the message.
		text, entities = entityBuilder.Complete()
	}
	if task.ReplyMessageID != 0 {
		extCtx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
			Message:     text,
			Entities:    entities,
			ID:          task.ReplyMessageID,
			ReplyMarkup: getCancelTaskMarkup(task),
		})
	}

	// Capacity 1: the worker never blocks on send after the caller has
	// already returned because of cancellation.
	resultCh := make(chan error, 1)
	go func() {
		page, err := bot.TelegraphClient.GetPage(tgphPath, true)
		if err != nil {
			resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err)
			return
		}

		imgs := make([]string, 0)
		for _, element := range page.Content {
			// Round-trip through JSON to decode the element into a NodeElement.
			var node telegraph.NodeElement
			data, err := json.Marshal(element)
			if err != nil {
				common.Log.Errorf("Failed to marshal element: %s", err)
				continue
			}
			err = json.Unmarshal(data, &node)
			if err != nil {
				common.Log.Errorf("Failed to unmarshal element: %s", err)
				continue
			}
			if len(node.Children) != 0 {
				for _, child := range node.Children {
					imgs = append(imgs, getNodeImages(child)...)
				}
			}
			if node.Tag == "img" {
				if src, ok := node.Attrs["src"]; ok {
					imgs = append(imgs, src)
				}
			}
		}
		if len(imgs) == 0 {
			resultCh <- fmt.Errorf("没有找到图片")
			return
		}

		hc := bot.TelegraphClient.HttpClient

		// saveImage performs a single download-and-save attempt. Keeping it in
		// its own function lets the deferred Close run at the end of every
		// attempt instead of piling up until all retries are done.
		saveImage := func(ctx context.Context, src, targetPath string) error {
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, src, nil)
			if err != nil {
				return fmt.Errorf("创建请求失败: %w", err)
			}
			resp, err := hc.Do(req)
			if err != nil {
				return fmt.Errorf("发送请求失败: %w", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				return fmt.Errorf("请求图片失败: %s", resp.Status)
			}
			if err := taskStorage.Save(ctx, resp.Body, targetPath); err != nil {
				return fmt.Errorf("保存图片失败: %w", err)
			}
			return nil
		}

		eg, ectx := errgroup.WithContext(cancelCtx)
		eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this
		// Always make at least one attempt, even when retries are configured as 0.
		attempts := max(config.Cfg.Retry, 1)
		for i, img := range imgs {
			if strings.HasPrefix(img, "/file/") {
				img = "https://telegra.ph" + img
			}
			idx, src := i, img
			eg.Go(func() error {
				targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", idx+1, path.Ext(src)))
				var lastErr error
				for attempt := range attempts {
					if attempt > 0 {
						retryDelay := time.Duration(attempt*attempt) * time.Second
						select {
						case <-ectx.Done():
							return ectx.Err()
						case <-time.After(retryDelay):
						}
						common.Log.Debugf("Retrying to download image %s (attempt %d)", src, attempt+1)
					}
					if lastErr = saveImage(ectx, src, targetPath); lastErr == nil {
						common.Log.Infof("Saved image: %s", targetPath)
						return nil
					}
				}
				return lastErr
			})
		}
		resultCh <- eg.Wait()
	}()

	select {
	case err := <-resultCh:
		return err
	case <-cancelCtx.Done():
		return cancelCtx.Err()
	}
}

View File

@@ -1,80 +0,0 @@
package core
import (
"reflect"
"testing"
"github.com/celestix/telegraph-go/v2"
)
// TestGetImgSrcs checks that getNodeImages collects the src attribute of every
// img element in a nested node tree, in document order, skipping text nodes
// and non-image elements.
func TestGetImgSrcs(t *testing.T) {
	// img builds an <img> element pointing at src.
	img := func(src string) telegraph.NodeElement {
		return telegraph.NodeElement{
			Tag:   "img",
			Attrs: map[string]string{"src": src},
		}
	}

	complexStructure := telegraph.NodeElement{
		Tag: "div",
		Children: []telegraph.Node{
			telegraph.NodeElement{
				Tag: "figure",
				Children: []telegraph.Node{
					img("https://example.com/image1.png"),
					telegraph.NodeElement{
						Tag: "p",
						Children: []telegraph.Node{
							"A text node",
						},
					},
					telegraph.NodeElement{
						Tag: "figure",
						Children: []telegraph.Node{
							img("https://example.com/image2.png"),
						},
					},
				},
			},
			img("https://example.com/image3.png"),
			"text node",
			telegraph.NodeElement{
				Tag: "div",
				Children: []telegraph.Node{
					telegraph.NodeElement{
						Tag: "span",
						Children: []telegraph.Node{
							img("https://example.com/image4.png"),
						},
					},
				},
			},
		},
	}

	expected := []string{
		"https://example.com/image1.png",
		"https://example.com/image2.png",
		"https://example.com/image3.png",
		"https://example.com/image4.png",
	}

	got := getNodeImages(complexStructure)
	if !reflect.DeepEqual(expected, got) {
		// The two values are separated so the failure message is readable.
		t.Errorf("expected %v, got %v", expected, got)
	}
}

View File

@@ -1,110 +0,0 @@
package core
import (
"fmt"
"path"
"regexp"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
// getStorageAndPathForTask resolves where a task's file should be stored.
//
// The user's rules are tried in order when rule application is enabled; the
// first rule that yields both a storage and a path wins. Otherwise the
// storage named by the task is used, and when the computed path already
// exists there the file name gets a "_<FileDBID>" suffix before the
// extension and the path is recomputed.
//
// As a side effect task.StoragePath is defaulted to the file name when empty,
// and both task.File.FileName and task.StoragePath are rewritten in the
// "already exists" case.
func getStorageAndPathForTask(task *types.Task) (storage.Storage, string, error) {
	user, err := dao.GetUserByChatID(task.UserID)
	if err != nil {
		return nil, "", fmt.Errorf("failed to get user by chat ID: %w", err)
	}
	if task.StoragePath == "" {
		task.StoragePath = task.FileName()
	}

	chosen, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
	if err != nil {
		return nil, "", err
	}
	destPath := chosen.JoinStoragePath(*task)

	if user.ApplyRule && user.Rules != nil {
		for _, rule := range user.Rules {
			ruleStorage, rulePath := applyRule(&rule, *task)
			if ruleStorage == nil || rulePath == "" {
				continue
			}
			common.Log.Debugf("Rule matched: %s, %s", ruleStorage.Name(), rulePath)
			return ruleStorage, rulePath, nil
		}
	}

	if !chosen.Exists(task.Ctx, destPath) {
		return chosen, destPath, nil
	}

	// Avoid overwriting an existing file: insert the DB ID before the extension.
	fullName := task.FileName()
	ext := path.Ext(fullName)
	base := fullName[:len(fullName)-len(ext)]
	task.File.FileName = fmt.Sprintf("%s_%d%s", base, task.FileDBID, ext)
	task.StoragePath = task.File.FileName
	destPath = chosen.JoinStoragePath(*task)
	return chosen, destPath, nil
}
// applyRule checks a single rule against task and, when it matches, returns
// the storage and storage path the rule selects. It returns (nil, "") when
// the rule does not match or cannot be evaluated; failures are logged and
// never abort the process.
//
// Supported rule types match rule.Data as a regular expression against the
// task's file name or against the text of the source message. An empty
// storage name or the literal "CHOSEN" in the rule means "keep the storage
// the task already names".
//
// task is received by value, so prefixing its StoragePath with the rule's
// directory does not affect the caller's task.
func applyRule(rule *dao.Rule, task types.Task) (storage.Storage, string) {
	var dirPath, storageName string
	switch rule.Type {
	case string(types.RuleTypeFileNameRegex):
		ruleRegex, err := regexp.Compile(rule.Data)
		if err != nil {
			common.Log.Errorf("failed to compile regex: %s", err)
			return nil, ""
		}
		if !ruleRegex.MatchString(task.FileName()) {
			return nil, ""
		}
		dirPath = rule.DirPath
		storageName = rule.StorageName
	case string(types.RuleTypeMessageRegex):
		ruleRegex, err := regexp.Compile(rule.Data)
		if err != nil {
			common.Log.Errorf("failed to compile regex: %s", err)
			return nil, ""
		}
		ctx, ok := task.Ctx.(*ext.Context)
		if !ok {
			// A rule that cannot be evaluated is skipped; it must not take the
			// whole process down.
			common.Log.Errorf("context is not *ext.Context: %T", task.Ctx)
			return nil, ""
		}
		msg, err := bot.GetTGMessage(ctx, task.FileChatID, task.FileMessageID)
		if err != nil {
			common.Log.Errorf("failed to get message: %s", err)
			return nil, ""
		}
		if msg == nil {
			return nil, ""
		}
		if !ruleRegex.MatchString(msg.GetMessage()) {
			return nil, ""
		}
		dirPath = rule.DirPath
		storageName = rule.StorageName
	default:
		common.Log.Errorf("unknown rule type: %s", rule.Type)
		return nil, ""
	}

	// Fall back to the storage already named by the task.
	taskStorageName := storageName
	if storageName == "" || storageName == "CHOSEN" {
		taskStorageName = task.StorageName
	}
	taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, taskStorageName)
	if err != nil {
		common.Log.Errorf("failed to get storage: %s", err)
		return nil, ""
	}
	task.StoragePath = path.Join(dirPath, task.StoragePath)
	return taskStorage, taskStorage.JoinStoragePath(task)
}

82
core/tftask/execute.go Normal file
View File

@@ -0,0 +1,82 @@
package tftask
import (
"context"
"fmt"
"os"
"path"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/key"
)
// Execute downloads the task's file and saves it to the task's storage.
//
// Stream-mode tasks are delegated to executeStream. Otherwise the file is
// downloaded to t.localPath, an extension is appended to the storage path
// when the file name has none and one can be detected, and the cached file is
// saved with up to config.Cfg.Retry retries and a growing delay between
// attempts. The local file is closed and removed before the function returns.
//
// Progress.OnStart is called first. For the non-stream path Progress.OnDone
// is called exactly once on every return path, including a failure to create
// the local file, with the underlying error (nil on success).
func (t *TGFileTask) Execute(ctx context.Context) error {
	logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
	t.Progress.OnStart(ctx, t)
	if t.stream {
		return executeStream(ctx, t)
	}

	// err always holds the most recent underlying failure; it is what the
	// deferred OnDone reports. Registered before any step that can fail so
	// the tracker is always told how the task ended.
	var err error
	defer func() {
		t.Progress.OnDone(ctx, t, err)
	}()

	logger.Info("Starting file download")
	localFile, createErr := fsutil.CreateFile(t.localPath)
	if createErr != nil {
		err = createErr
		return fmt.Errorf("failed to create local file: %w", err)
	}
	defer func() {
		if rmErr := localFile.CloseAndRemove(); rmErr != nil {
			logger.Errorf("Failed to close local file: %v", rmErr)
		}
	}()

	wrAt := newWriterAt(ctx, localFile, t.Progress, t)
	_, err = tdler.NewDownloader(t.client, t.File).Parallel(ctx, wrAt)
	if err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	logger.Infof("File downloaded successfully")

	if path.Ext(t.File.Name()) == "" {
		ext := fsutil.DetectFileExt(t.localPath)
		if ext != "" {
			t.Path = t.Path + ext
		}
	}

	var fileStat os.FileInfo
	fileStat, err = os.Stat(t.localPath)
	if err != nil {
		return fmt.Errorf("failed to get file stat: %w", err)
	}
	// The content length travels to the storage through the context.
	vctx := context.WithValue(ctx, key.ContextKeyContentLength, fileStat.Size())

	for i := range config.Cfg.Retry + 1 {
		if err = vctx.Err(); err != nil {
			return fmt.Errorf("context canceled while saving file: %w", err)
		}
		var file *os.File
		file, err = os.Open(t.localPath)
		if err != nil {
			return fmt.Errorf("failed to open cache file: %w", err)
		}
		err = t.Storage.Save(vctx, file, t.Path)
		// Close right away rather than deferring inside the loop, so a retry
		// does not keep the previous handle open. The file was opened
		// read-only, so the close error carries no information.
		_ = file.Close()
		if err == nil {
			return nil
		}
		if i == config.Cfg.Retry {
			return fmt.Errorf("failed to save file: %w", err)
		}
		logger.Errorf("Failed to save file: %s, retrying...", err)
		select {
		case <-vctx.Done():
			return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
		case <-time.After(time.Duration(i*500) * time.Millisecond):
		}
	}

	// Only reachable when the loop ran zero times (negative retry setting).
	// Record the failure so OnDone does not report success.
	err = fmt.Errorf("failed to save file after retries")
	return err
}

186
core/tftask/progress.go Normal file
View File

@@ -0,0 +1,186 @@
package tftask
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
// ProgressTracker receives lifecycle callbacks for a single-file task.
type ProgressTracker interface {
// OnStart is called once before the download begins.
OnStart(ctx context.Context, info TaskInfo)
// OnProgress is called with the bytes downloaded so far and the total size.
OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64)
// OnDone is called when the task finishes; err is nil on success.
OnDone(ctx context.Context, info TaskInfo, err error)
}
// Progress is a ProgressTracker that reports a task's state by editing a
// single chat message.
type Progress struct {
// MessageID is the ID of the message that is edited on every update.
MessageID int
// ChatID is the chat that contains the message.
ChatID int64
// start is set in OnStart and used to compute the average speed.
start time.Time
// lastUpdatePercent is the whole-number percentage of the last update sent.
lastUpdatePercent atomic.Int32
}
// OnStart records the start time, resets the last reported percentage and
// edits the tracked message to announce the download (file name, destination
// and size) together with a cancel button. Nothing is sent when the context
// carries no ext context or when the entities cannot be built.
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
	p.start = time.Now()
	p.lastUpdatePercent.Store(0)

	logger := log.FromContext(ctx)
	logger.Debugf("Progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)

	var builder entity.Builder
	buildErr := styling.Perform(&builder,
		styling.Plain("开始下载\n文件名: "),
		styling.Code(info.FileName()),
		styling.Plain("\n保存路径: "),
		styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
		styling.Plain("\n文件大小: "),
		styling.Code(fmt.Sprintf("%.2f MB", float64(info.FileSize())/(1024*1024))),
	)
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}
	text, entities := builder.Complete()

	cancelRow := tg.KeyboardButtonRow{
		Buttons: []tg.KeyboardButtonClass{tgutil.BuildCancelButton(info.TaskID())},
	}
	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{Rows: []tg.KeyboardButtonRow{cancelRow}})

	if extCtx := tgutil.ExtFromContext(ctx); extCtx != nil {
		extCtx.EditMessage(p.ChatID, req)
	}
}
// OnProgress edits the tracked message with the current download statistics
// (size, average speed since OnStart, percentage) whenever
// shouldUpdateProgress decides that enough progress was made since the last
// reported percentage.
//
// downloaded and total are byte counts. The last reported percentage is
// advanced with a compare-and-swap so that, when several goroutines report
// progress at the same time, only one of them sends the edit for a given step.
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo, downloaded, total int64) {
	last := p.lastUpdatePercent.Load()
	if !shouldUpdateProgress(total, downloaded, int(last)) {
		return
	}
	// total > 0 is guaranteed here: shouldUpdateProgress rejects total <= 0.
	percent := int32((downloaded * 100) / total)
	// Claim this update; if another goroutine changed the stored percentage
	// in the meantime it is already sending a newer message.
	if !p.lastUpdatePercent.CompareAndSwap(last, percent) {
		return
	}

	logger := log.FromContext(ctx)
	logger.Debugf("Progress update: %s, %d/%d", info.FileName(), downloaded, total)

	entityBuilder := entity.Builder{}
	if err := styling.Perform(&entityBuilder,
		styling.Plain("正在处理下载任务\n文件名: "),
		styling.Code(info.FileName()),
		styling.Plain("\n保存路径: "),
		styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
		styling.Plain("\n文件大小: "),
		styling.Code(fmt.Sprintf("%.2f MB", float64(total)/(1024*1024))),
		styling.Plain("\n平均速度: "),
		styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(downloaded, p.start)/(1024*1024))),
		styling.Plain("\n当前进度: "),
		styling.Bold(fmt.Sprintf("%.2f%%", float64(downloaded)/float64(total)*100)),
	); err != nil {
		logger.Errorf("Failed to build entities: %s", err)
		return
	}
	text, entities := entityBuilder.Complete()

	req := &tg.MessagesEditMessageRequest{
		ID: p.MessageID,
	}
	req.SetMessage(text)
	req.SetEntities(entities)
	// Keep the cancel button available while the task is running.
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{
		Rows: []tg.KeyboardButtonRow{
			{
				Buttons: []tg.KeyboardButtonClass{
					tgutil.BuildCancelButton(info.TaskID()),
				},
			},
		}},
	)
	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// OnDone logs the outcome of the task and edits the tracked message with the
// final state: cancelled (err wraps context.Canceled), failed (any other
// error, including its text) or completed (err is nil, showing the
// destination). The final message carries no reply markup.
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
	logger := log.FromContext(ctx)
	builder := entity.Builder{}
	var buildErr error

	switch {
	case err == nil:
		logger.Debugf("Progress done for file [%s]", info.FileName())
		buildErr = styling.Perform(&builder,
			styling.Plain("下载完成\n文件名: "),
			styling.Code(info.FileName()),
			styling.Plain("\n保存路径: "),
			styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
		)
	case errors.Is(err, context.Canceled):
		logger.Errorf("Progress error for file [%s]: %v", info.FileName(), err)
		buildErr = styling.Perform(&builder,
			styling.Plain("任务已取消\n文件名: "),
			styling.Code(info.FileName()),
		)
	default:
		logger.Errorf("Progress error for file [%s]: %v", info.FileName(), err)
		buildErr = styling.Perform(&builder,
			styling.Plain("下载失败\n文件名: "),
			styling.Code(info.FileName()),
			styling.Plain("\n错误: "),
			styling.Bold(err.Error()),
		)
	}
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}

	text, entities := builder.Complete()
	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)
	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// ProgressOption adjusts a Progress while NewProgressTrack is building it.
type ProgressOption func(*Progress)

// NewProgressTrack returns a ProgressTracker backed by a *Progress bound to
// the given message and chat. The options are applied in the order given,
// after MessageID and ChatID have been set.
func NewProgressTrack(
	messageID int,
	chatID int64,
	opts ...ProgressOption,
) ProgressTracker {
	tracker := new(Progress)
	tracker.MessageID = messageID
	tracker.ChatID = chatID
	for i := range opts {
		opts[i](tracker)
	}
	return tracker
}

40
core/tftask/stream.go Normal file
View File

@@ -0,0 +1,40 @@
package tftask
import (
"context"
"fmt"
"io"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/common/tdler"
"golang.org/x/sync/errgroup"
)
// executeStream downloads the task's file and uploads it to the task's
// storage concurrently, connecting both sides with an in-memory pipe so no
// cache file is needed.
//
// One goroutine feeds the pipe from the downloader (through a progress
// reporting writer), the other hands the read side to Storage.Save. The first
// error cancels the shared context and is returned. Progress.OnDone is always
// called with the final result.
func executeStream(ctx context.Context, task *TGFileTask) error {
	logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
	pr, pw := io.Pipe()
	defer pr.Close()
	errg, uploadCtx := errgroup.WithContext(ctx)
	errg.Go(func() error {
		saveErr := task.Storage.Save(uploadCtx, pr, task.Path)
		// Close the read side as soon as the upload returns. A pipe write
		// blocks until it is read or the reader is closed, so without this a
		// Save that stops early would leave the download goroutine stuck in
		// Write and errg.Wait would never return. CloseWithError always
		// returns nil.
		_ = pr.CloseWithError(saveErr)
		return saveErr
	})
	wr := newWriter(ctx, pw, task.Progress, task)
	errg.Go(func() error {
		logger.Info("Starting file download in stream mode")
		_, err := tdler.NewDownloader(task.client, task.File).Stream(uploadCtx, wr)
		// Propagate the download result to the reader: nil becomes EOF.
		if closeErr := pw.CloseWithError(err); closeErr != nil {
			logger.Errorf("Failed to close pipe writer: %v", closeErr)
		}
		return err
	})
	var err error
	defer func() {
		task.Progress.OnDone(ctx, task, err)
	}()
	if err = errg.Wait(); err != nil {
		return err
	}
	logger.Info("File downloaded successfully in stream mode")
	return nil
}

29
core/tftask/taskinfo.go Normal file
View File

@@ -0,0 +1,29 @@
package tftask
// TaskInfo is the read-only view of a file task that progress trackers use
// when they render status messages.
type TaskInfo interface {
// TaskID returns the identifier of the task.
TaskID() string
// FileName returns the name of the file being transferred.
FileName() string
// FileSize returns the size of the file in bytes.
FileSize() int64
// StoragePath returns the destination path inside the storage.
StoragePath() string
// StorageName returns the name of the destination storage.
StorageName() string
}
// TaskID returns the task's ID field.
func (t *TGFileTask) TaskID() string {
return t.ID
}
// FileName returns the name reported by the task's file.
func (t *TGFileTask) FileName() string {
return t.File.Name()
}
// FileSize returns the size reported by the task's file.
func (t *TGFileTask) FileSize() int64 {
return t.File.Size()
}
// StoragePath returns the task's Path field.
func (t *TGFileTask) StoragePath() string {
return t.Path
}
// StorageName returns the name reported by the task's storage.
func (t *TGFileTask) StorageName() string {
return t.Storage.Name()
}

64
core/tftask/tftask.go Normal file
View File

@@ -0,0 +1,64 @@
package tftask
import (
"context"
"fmt"
"path/filepath"
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
)
// TGFileTask describes the transfer of one Telegram file into a storage.
type TGFileTask struct {
// ID identifies the task; it is also used as a prefix of the cache file name.
ID string
// Ctx is the context supplied when the task was created.
Ctx context.Context
// File is the Telegram file to transfer.
File tfile.TGFile
// Storage is the destination storage.
Storage storage.Storage
// Path is the destination path passed to Storage.Save.
Path string
// Progress receives start/progress/done notifications.
Progress ProgressTracker
// client is handed to the downloader.
client tdler.Client
stream bool // true if the file should be downloaded in stream mode
// localPath is the absolute cache file path; NewTGFileTask sets it only for non-stream tasks.
localPath string
}
// NewTGFileTask creates a task that transfers file into stor at path.
//
// The task runs in stream mode when streaming is enabled in the configuration
// and the storage does not implement storage.StorageCannotStream. Otherwise
// an absolute cache file path named "<id>_<file name>" inside the configured
// temp directory is computed and stored on the task; an error is returned if
// that path cannot be made absolute.
func NewTGFileTask(
	id string,
	ctx context.Context,
	file tfile.TGFile,
	client tdler.Client,
	stor storage.Storage,
	path string,
	progress ProgressTracker,
) (*TGFileTask, error) {
	task := &TGFileTask{
		ID:       id,
		Ctx:      ctx,
		client:   client,
		File:     file,
		Storage:  stor,
		Path:     path,
		Progress: progress,
	}

	if _, cannotStream := stor.(storage.StorageCannotStream); config.Cfg.Stream && !cannotStream {
		task.stream = true
		return task, nil
	}

	cacheName := fmt.Sprintf("%s_%s", id, file.Name())
	cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, cacheName))
	if err != nil {
		return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
	}
	task.localPath = cachePath
	return task, nil
}

32
core/tftask/util.go Normal file
View File

@@ -0,0 +1,32 @@
package tftask
// progressUpdatesLevels maps file sizes to how often progress is reported.
// Entries are ordered by ascending size; the first entry whose size is
// strictly greater than the file's total size applies.
var progressUpdatesLevels = []struct {
	size        int64 // file size threshold in bytes (exclusive upper bound)
	stepPercent int   // report once every stepPercent percent
}{
	{10 << 20, 100},
	{50 << 20, 20},
	{200 << 20, 10},
	{500 << 20, 5},
}

// shouldUpdateProgress reports whether a progress update should be sent.
//
// It returns false when total or downloaded is not positive, or when the
// integer percentage has not advanced past lastUpdatePercent. Otherwise the
// percentage must have advanced by at least the step that applies to total;
// files at or above the largest threshold use the last level's step.
func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool {
	if total <= 0 || downloaded <= 0 {
		return false
	}

	current := int(downloaded * 100 / total)
	if current <= lastUpdatePercent {
		return false
	}

	// Default to the last level, then look for the first threshold above total.
	required := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent
	for i := 0; i < len(progressUpdatesLevels); i++ {
		if total < progressUpdatesLevels[i].size {
			required = progressUpdatesLevels[i].stepPercent
			break
		}
	}
	return current >= lastUpdatePercent+required
}

75
core/tftask/writer.go Normal file
View File

@@ -0,0 +1,75 @@
package tftask
import (
"context"
"io"
"sync/atomic"
)
// ProgressWriterAt wraps an io.WriterAt and reports the running total of
// written bytes to a ProgressTracker after every successful write.
type ProgressWriterAt struct {
	ctx        context.Context
	wrAt       io.WriterAt
	progress   ProgressTracker
	downloaded *atomic.Int64 // cumulative bytes written through this wrapper
	total      int64         // expected size, taken from TaskInfo.FileSize
	info       TaskInfo
}

// WriteAt forwards the write to the wrapped io.WriterAt.
//
// On success the written byte count is added to the running total and the
// tracker is notified. On failure the byte count reported by the wrapped
// writer is returned together with the error, as the io.WriterAt contract
// requires, and no progress is reported.
func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
	n, err := w.wrAt.WriteAt(p, off)
	if err != nil {
		return n, err
	}
	w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(n)), w.total)
	return n, nil
}

// newWriterAt builds a ProgressWriterAt around wrAt. The expected total is
// read once from taskInfo.FileSize and the byte counter starts at zero.
func newWriterAt(
	ctx context.Context,
	wrAt io.WriterAt,
	progress ProgressTracker,
	taskInfo TaskInfo,
) *ProgressWriterAt {
	return &ProgressWriterAt{
		ctx:        ctx,
		progress:   progress,
		downloaded: &atomic.Int64{},
		total:      taskInfo.FileSize(),
		wrAt:       wrAt,
		info:       taskInfo,
	}
}
// ProgressWriter wraps an io.Writer and reports the running total of written
// bytes to a ProgressTracker after every successful write.
type ProgressWriter struct {
	ctx        context.Context
	wrAt       io.Writer // wrapped writer (sequential, despite the field name)
	progress   ProgressTracker
	downloaded *atomic.Int64 // cumulative bytes written through this wrapper
	total      int64         // expected size, taken from TaskInfo.FileSize
	info       TaskInfo
}

// Write forwards the write to the wrapped io.Writer.
//
// On success the written byte count is added to the running total and the
// tracker is notified. On failure the byte count reported by the wrapped
// writer is returned together with the error, as the io.Writer contract
// requires, and no progress is reported.
func (w *ProgressWriter) Write(p []byte) (int, error) {
	n, err := w.wrAt.Write(p)
	if err != nil {
		return n, err
	}
	w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(n)), w.total)
	return n, nil
}

// newWriter builds a ProgressWriter around wr. The expected total is read
// once from taskInfo.FileSize and the byte counter starts at zero.
func newWriter(
	ctx context.Context,
	wr io.Writer,
	progress ProgressTracker,
	taskInfo TaskInfo,
) *ProgressWriter {
	return &ProgressWriter{
		ctx:        ctx,
		progress:   progress,
		downloaded: &atomic.Int64{},
		total:      taskInfo.FileSize(),
		wrAt:       wr,
		info:       taskInfo,
	}
}

94
core/tphtask/execute.go Normal file
View File

@@ -0,0 +1,94 @@
package tphtask
import (
"context"
"fmt"
"io"
"path"
"path/filepath"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/retry"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config"
"go.uber.org/multierr"
"golang.org/x/sync/errgroup"
)
// Execute processes every picture of the Telegraph page, running at most
// config.Cfg.Workers pictures concurrently.
//
// The tracker is notified before work starts, after each picture that was
// processed successfully and once at the end with the overall result. The
// first picture error cancels the remaining work and is returned.
func (t *Task) Execute(ctx context.Context) error {
	logger := log.FromContext(ctx)
	logger.Infof("Starting Telegraph task %s", t.PhPath)
	t.progress.OnStart(ctx, t)

	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(config.Cfg.Workers)
	for idx := 0; idx < len(t.Pics); idx++ {
		// Per-iteration copies so each goroutine sees its own values.
		picURL, picIndex := t.Pics[idx], idx
		group.Go(func() error {
			if procErr := t.processPic(groupCtx, picURL, picIndex); procErr != nil {
				logger.Errorf("Error processing picture %s: %v", picURL, procErr)
				return fmt.Errorf("failed to process picture %s: %w", picURL, procErr)
			}
			t.downloaded.Add(1)
			t.progress.OnProgress(groupCtx, t)
			return nil
		})
	}

	result := group.Wait()
	if result == nil {
		logger.Infof("Telegraph task %s completed successfully", t.PhPath)
	} else {
		logger.Errorf("Error during Telegraph task execution: %v", result)
	}
	t.progress.OnDone(ctx, t, result)
	return result
}
// processPic downloads one picture and saves it to the task's storage as
// "<index+1><ext>" inside StorPath, where ext is taken from picUrl. The whole
// download-and-save sequence is retried through the retry package.
//
// When the storage cannot stream, the picture is first copied to a cache file
// in the temp directory (removed again when the attempt ends) and that file
// is passed to Save; otherwise the response body is passed directly.
//
// The returned error combines the retry helper's error with the error of the
// last attempt.
//
// NOTE(review): config.Cfg.Retry is passed to retry.RetryTimes unchanged —
// confirm how the retry package treats a value of 0.
func (t *Task) processPic(ctx context.Context, picUrl string, index int) error {
	retryOpts := []retry.Option{
		retry.Context(ctx),
		retry.RetryTimes(uint(config.Cfg.Retry)),
	}
	// lastErr keeps the error of the most recent attempt so it survives the
	// retry helper's own error.
	var lastErr error
	err := retry.Retry(func() error {
		var body io.ReadCloser
		body, lastErr = t.client.Download(ctx, picUrl)
		if lastErr != nil {
			lastErr = fmt.Errorf("failed to download picture %s: %w", picUrl, lastErr)
			return lastErr
		}
		defer body.Close()
		filename := fmt.Sprintf("%d%s", index+1, path.Ext(picUrl))
		if t.cannotStream {
			cacheFile, err := fsutil.CreateFile(filepath.Join(config.Cfg.Temp.BasePath,
				fmt.Sprintf("tph_%s_%s", t.TaskID(), filename),
			))
			if err != nil {
				lastErr = fmt.Errorf("failed to create cache file for picture %s: %w", filename, err)
				return lastErr
			}
			defer func() {
				if err := cacheFile.CloseAndRemove(); err != nil {
					logger := log.FromContext(ctx)
					logger.Errorf("Failed to close and remove cache file for picture %s: %v", filename, err)
				}
			}()
			_, lastErr = io.Copy(cacheFile, body)
			if lastErr != nil {
				lastErr = fmt.Errorf("failed to copy picture %s to cache file: %w", filename, lastErr)
				return lastErr
			}
			// The copy left the offset at the end of the data. Rewind before
			// handing the file to Save so the storage reads the picture from
			// its start. The assertion keeps this a no-op if the cache file
			// type is not seekable.
			if seeker, ok := any(cacheFile).(io.Seeker); ok {
				if _, lastErr = seeker.Seek(0, io.SeekStart); lastErr != nil {
					lastErr = fmt.Errorf("failed to rewind cache file for picture %s: %w", filename, lastErr)
					return lastErr
				}
			}
			lastErr = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename))
		} else {
			lastErr = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename))
		}
		if lastErr != nil {
			lastErr = fmt.Errorf("failed to save picture %s: %w", filename, lastErr)
			return lastErr
		}
		return nil
	}, retryOpts...)
	return multierr.Combine(err, lastErr)
}

150
core/tphtask/progress.go Normal file
View File

@@ -0,0 +1,150 @@
package tphtask
import (
"context"
"errors"
"fmt"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
// ProgressTracker receives notifications about the lifecycle of a Telegraph task.
type ProgressTracker interface {
// OnStart is called by Task.Execute before any picture is processed.
OnStart(ctx context.Context, info TaskInfo)
// OnProgress is called by Task.Execute after each picture that was processed successfully.
OnProgress(ctx context.Context, info TaskInfo)
// OnDone is called by Task.Execute once all work has finished; err is nil on success.
OnDone(ctx context.Context, info TaskInfo, err error)
}
// Progress reports the state of a Telegraph task by editing one Telegram message.
type Progress struct {
// MessageID is the ID of the message that is edited.
MessageID int
// ChatID is passed to EditMessage together with every edit request.
ChatID int64
}
// OnStart edits the tracked message to announce the start of the Telegraph
// download, showing the number of pictures and a cancel button. The message
// is only edited when an ext context is attached to ctx.
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
	logger := log.FromContext(ctx)
	logger.Debugf("Telegraph task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)

	builder := entity.Builder{}
	buildErr := styling.Perform(&builder,
		styling.Plain("开始下载Telegraph\n图片数量: "),
		styling.Code(fmt.Sprintf("%d", info.TotalPics())),
	)
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}
	text, entities := builder.Complete()

	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)

	// A single row holding the cancel button for this task.
	cancelRow := tg.KeyboardButtonRow{
		Buttons: []tg.KeyboardButtonClass{tgutil.BuildCancelButton(info.TaskID())},
	}
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{Rows: []tg.KeyboardButtonRow{cancelRow}})

	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// OnProgress edits the tracked message with the "downloaded/total" picture
// count and a cancel button, but only when shouldUpdateProgress allows an
// update for the current counts. The message is only edited when an ext
// context is attached to ctx.
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
	if allowed := shouldUpdateProgress(info.Downloaded(), int64(info.TotalPics())); !allowed {
		return
	}

	logger := log.FromContext(ctx)
	logger.Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Downloaded(), info.TotalPics())

	builder := entity.Builder{}
	buildErr := styling.Perform(&builder,
		styling.Plain("正在下载\n当前进度: "),
		styling.Code(fmt.Sprintf("%d/%d", info.Downloaded(), info.TotalPics())),
	)
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}
	text, entities := builder.Complete()

	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)

	// Keep the cancel button available while the task is running.
	cancelRow := tg.KeyboardButtonRow{
		Buttons: []tg.KeyboardButtonClass{tgutil.BuildCancelButton(info.TaskID())},
	}
	req.SetReplyMarkup(&tg.ReplyInlineMarkup{Rows: []tg.KeyboardButtonRow{cancelRow}})

	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// OnDone logs the outcome of the Telegraph task and edits the tracked message
// accordingly: a plain "cancelled" text when err wraps context.Canceled, a
// plain "failed" text carrying the error for any other error, or a styled
// summary with the picture count and destination on success. The message is
// only edited when an ext context is attached to ctx.
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
	logger := log.FromContext(ctx)

	switch {
	case err == nil:
		// Success is handled below.
	case errors.Is(err, context.Canceled):
		logger.Infof("Telegraph task %s was canceled", info.TaskID())
		if ext := tgutil.ExtFromContext(ctx); ext != nil {
			ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{
				ID:      p.MessageID,
				Message: fmt.Sprintf("处理已取消: %s", info.TaskID()),
			})
		}
		return
	default:
		logger.Errorf("Telegraph task %s failed: %s", info.TaskID(), err)
		if ext := tgutil.ExtFromContext(ctx); ext != nil {
			ext.EditMessage(p.ChatID, &tg.MessagesEditMessageRequest{
				ID:      p.MessageID,
				Message: fmt.Sprintf("处理失败: %s", err.Error()),
			})
		}
		return
	}

	logger.Infof("Telegraph task %s completed successfully", info.TaskID())
	builder := entity.Builder{}
	buildErr := styling.Perform(&builder,
		styling.Plain("处理完成\n图片数量: "),
		styling.Code(fmt.Sprintf("%d", info.TotalPics())),
		styling.Plain("\n保存路径: "),
		styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())),
	)
	if buildErr != nil {
		logger.Errorf("Failed to build entities: %s", buildErr)
		return
	}

	text, entities := builder.Complete()
	req := &tg.MessagesEditMessageRequest{ID: p.MessageID}
	req.SetMessage(text)
	req.SetEntities(entities)
	if ext := tgutil.ExtFromContext(ctx); ext != nil {
		ext.EditMessage(p.ChatID, req)
	}
}
// NewProgress returns a Progress bound to the given message and chat.
func NewProgress(messageID int, chatID int64) *Progress {
	p := new(Progress)
	p.MessageID, p.ChatID = messageID, chatID
	return p
}

51
core/tphtask/task.go Normal file
View File

@@ -0,0 +1,51 @@
package tphtask
import (
"context"
"sync/atomic"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
)
// Task describes saving the pictures of one Telegraph page into a storage.
type Task struct {
// ID identifies the task; it is also part of the cache file names.
ID string
// Ctx is the context supplied when the task was created.
Ctx context.Context
// PhPath is the Telegraph page path, used in log messages.
PhPath string
// Pics holds the picture URLs to download.
Pics []string
// Stor is the destination storage.
Stor storage.Storage
// StorPath is the destination directory inside the storage.
StorPath string
// client downloads the pictures.
client *telegraph.Client
// progress receives start/progress/done notifications.
progress ProgressTracker
// cannotStream is true when Stor implements storage.StorageCannotStream.
cannotStream bool
// totalpics is len(Pics) at construction time.
totalpics int
// downloaded counts the pictures that were processed successfully.
downloaded atomic.Int64
}
// NewTask creates a Telegraph task for the given pictures. The task records
// whether stor implements storage.StorageCannotStream and how many pictures
// it has; its counter of processed pictures starts at zero.
func NewTask(
	id string,
	ctx context.Context,
	phPath string,
	pics []string,
	stor storage.Storage,
	storPath string,
	client *telegraph.Client,
	progress ProgressTracker,
) *Task {
	task := &Task{
		ID:        id,
		Ctx:       ctx,
		PhPath:    phPath,
		Pics:      pics,
		Stor:      stor,
		StorPath:  storPath,
		client:    client,
		progress:  progress,
		totalpics: len(pics),
	}
	// downloaded is left at its zero value.
	_, task.cannotStream = stor.(storage.StorageCannotStream)
	return task
}

34
core/tphtask/taskinfo.go Normal file
View File

@@ -0,0 +1,34 @@
package tphtask
// TaskInfo is the read-only view of a Telegraph task that progress trackers
// use when they render status messages.
type TaskInfo interface {
// TaskID returns the identifier of the task.
TaskID() string
// Phpath returns the Telegraph page path.
Phpath() string
// TotalPics returns the number of pictures in the task.
TotalPics() int
// Downloaded returns the number of pictures processed so far.
Downloaded() int64
// StorageName returns the name of the destination storage.
StorageName() string
// StoragePath returns the destination path inside the storage.
StoragePath() string
}
// TaskID returns the task's ID field.
func (t *Task) TaskID() string {
return t.ID
}
// Phpath returns the task's PhPath field.
func (t *Task) Phpath() string {
return t.PhPath
}
// TotalPics returns the picture count recorded for the task.
func (t *Task) TotalPics() int {
return t.totalpics
}
// Downloaded atomically loads the counter of processed pictures.
func (t *Task) Downloaded() int64 {
return t.downloaded.Load()
}
// StorageName returns the name reported by the task's storage.
func (t *Task) StorageName() string {
return t.Stor.Name()
}
// StoragePath returns the task's StorPath field.
func (t *Task) StoragePath() string {
return t.StorPath
}

13
core/tphtask/utils.go Normal file
View File

@@ -0,0 +1,13 @@
package tphtask
// shouldUpdateProgress reports whether a progress update should be sent for
// the given counts: never for non-positive values, always when everything has
// been downloaded, and otherwise on every tenth item.
func shouldUpdateProgress(downloaded int64, total int64) bool {
	const step int64 = 10
	switch {
	case total <= 0, downloaded <= 0:
		return false
	case downloaded == total:
		return true
	default:
		return downloaded >= step && downloaded%step == 0
	}
}

View File

@@ -1,303 +0,0 @@
package core
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"path"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/userclient"
)
// saveFileWithRetry uploads the cache file at cacheFilePath to taskStorage
// under storagePath, making up to config.Cfg.Retry+1 attempts.
//
// The file size is attached to the context passed to Save under
// types.ContextKeyContentLength. Every attempt opens the cache file afresh so
// the storage always reads from the start, and closes it again before the
// next attempt, so handles do not pile up across retries. Before retry i+1
// the function waits i*500ms (no wait before the first retry) or until the
// context is cancelled.
//
// Failing to open or stat the file aborts immediately; a failing Save is
// retried, and the last Save error is returned wrapped.
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
	file, err := os.Open(cacheFilePath)
	if err != nil {
		return fmt.Errorf("failed to open cache file: %w", err)
	}
	fileStat, err := file.Stat()
	// This handle is only needed for the size; release it right away.
	file.Close()
	if err != nil {
		return fmt.Errorf("failed to get file stat: %w", err)
	}

	vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
	for i := 0; i <= config.Cfg.Retry; i++ {
		if err := vctx.Err(); err != nil {
			return fmt.Errorf("context canceled while saving file: %w", err)
		}
		attemptFile, err := os.Open(cacheFilePath)
		if err != nil {
			return fmt.Errorf("failed to open cache file: %w", err)
		}
		saveErr := taskStorage.Save(vctx, attemptFile, storagePath)
		// Close per attempt instead of deferring inside the loop.
		attemptFile.Close()
		if saveErr == nil {
			return nil
		}
		if i == config.Cfg.Retry {
			return fmt.Errorf("failed to save file: %w", saveErr)
		}
		common.Log.Errorf("Failed to save file: %s, retrying...", saveErr)
		select {
		case <-vctx.Done():
			return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
		case <-time.After(time.Duration(i*500) * time.Millisecond):
		}
	}
	return nil
}
// processPhoto fetches the task's file with a single UploadGetFile request
// (offset 0, limit 1 MiB) and saves the returned bytes to taskStorage at the
// task's storage path. The user client's API is used when the task asks for
// it and a user client exists; otherwise the bot's API is used.
func processPhoto(task *types.Task, taskStorage storage.Storage) error {
	api := bot.Client.API()
	if task.UseUserClient && userclient.UC != nil {
		api = userclient.UC.API()
	}

	request := &tg.UploadGetFileRequest{
		Location: task.File.Location,
		Offset:   0,
		Limit:    1024 * 1024,
	}
	response, err := api.UploadGetFile(task.Ctx, request)
	if err != nil {
		return fmt.Errorf("failed to get file: %w", err)
	}

	uploaded, isFile := response.(*tg.UploadFile)
	if !isFile {
		return fmt.Errorf("unexpected type %T", response)
	}
	common.Log.Infof("Downloaded photo: %s", task.FileName())

	content := bytes.NewReader(uploaded.Bytes)
	return taskStorage.Save(task.Ctx, content, task.StoragePath)
}
// cleanCacheFile disposes of the cache file at destPath. With a positive
// cache TTL (in seconds) the removal is handed to common.RmFileAfter;
// otherwise the file is removed immediately and a failure is only logged.
func cleanCacheFile(destPath string) {
	if ttl := config.Cfg.Temp.CacheTTL; ttl > 0 {
		common.RmFileAfter(destPath, time.Duration(ttl)*time.Second)
		return
	}
	if removeErr := os.Remove(destPath); removeErr != nil {
		common.Log.Errorf("Failed to purge file: %s", removeErr)
	}
}
// getProgressUpdateCount returns how many progress updates should be sent for
// a file of fileSize bytes: 50 above 1000 MiB, 20 above 500 MiB, 10 above
// 200 MiB and 5 otherwise.
func getProgressUpdateCount(fileSize int64) int {
	const mib = 1024 * 1024
	switch {
	case fileSize > 1000*mib:
		return 50
	case fileSize > 500*mib:
		return 20
	case fileSize > 200*mib:
		return 10
	default:
		return 5
	}
}
// getSpeed formats the average transfer speed since startTime as "X.XXMB/s",
// based on bytesRead. A zero startTime yields "0MB/s".
func getSpeed(bytesRead int64, startTime time.Time) string {
	if startTime.IsZero() {
		return "0MB/s"
	}
	seconds := time.Since(startTime).Seconds()
	megabytes := float64(bytesRead) / 1024 / 1024
	return fmt.Sprintf("%.2fMB/s", megabytes/seconds)
}
// buildProgressMessageEntity renders the progress message for task.
//
// bytesRead and startTime feed the average speed, progress is the percentage
// to display. The styled text and its entities are returned; if styling
// fails, the plain-text version is returned with nil entities.
//
// The startTime parameter is used for both the styled and the plain-text
// rendering, so the two always show the same speed.
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
	destination := fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)
	speed := getSpeed(bytesRead, startTime)
	percent := fmt.Sprintf("%.2f%%", progress)

	entityBuilder := entity.Builder{}
	if err := styling.Perform(&entityBuilder,
		styling.Plain("正在处理下载任务\n文件名: "),
		styling.Code(task.FileName()),
		styling.Plain("\n保存路径: "),
		styling.Code(destination),
		styling.Plain("\n平均速度: "),
		styling.Bold(speed),
		styling.Plain("\n当前进度: "),
		styling.Bold(percent),
	); err != nil {
		common.Log.Errorf("Failed to build entities: %s", err)
		// Fall back to the same content without formatting.
		text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
			task.FileName(),
			destination,
			speed,
			progress,
		)
		return text, nil
	}
	return entityBuilder.Complete()
}
// buildProgressCallback returns a callback that edits the task's reply
// message with the current download progress.
//
// The message is edited only for files of at least 50 MiB, when the task has
// a reply message, and when the integer percentage is a non-zero multiple of
// 100/updateCount. An updateCount outside 1..100 falls back to a step of one
// percent instead of causing a division or modulo by zero, and calls with a
// non-positive contentLength are ignored because no percentage can be
// computed for them.
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
	step := 1
	if updateCount > 0 && updateCount <= 100 {
		step = 100 / updateCount
	}
	return func(bytesRead, contentLength int64) {
		if contentLength <= 0 {
			return
		}
		progress := float64(bytesRead) / float64(contentLength) * 100
		common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress)
		progressInt := int(progress)
		if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%step != 0 {
			return
		}
		if task.ReplyMessageID == 0 {
			return
		}
		text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
		ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
			Message:     text,
			Entities:    entities,
			ID:          task.ReplyMessageID,
			ReplyMarkup: getCancelTaskMarkup(task),
		})
	}
}
// getCancelTaskMarkup builds an inline keyboard with a single button whose
// callback data is "cancel <task key>".
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
	cancelButton := &tg.KeyboardButtonCallback{
		Text: "取消任务",
		Data: fmt.Appendf(nil, "cancel %s", task.Key()),
	}
	row := tg.KeyboardButtonRow{
		Buttons: []tg.KeyboardButtonClass{cancelButton},
	}
	return &tg.ReplyInlineMarkup{
		Rows: []tg.KeyboardButtonRow{row},
	}
}
// fixTaskFileExt appends a file extension to the task's file name and storage
// path when the file name has none. The extension comes from the MIME type
// detected for the local file; if detection fails the error is logged and the
// task is left unchanged.
func fixTaskFileExt(task *types.Task, localFilePath string) {
	if path.Ext(task.FileName()) != "" {
		return
	}
	detected, err := mimetype.DetectFile(localFilePath)
	if err != nil {
		common.Log.Errorf("Failed to detect mime type: %s", err)
		return
	}
	task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), detected.Extension())
	task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, detected.Extension())
}
// getTaskThreads returns the number of download threads for a file of
// fileSize bytes: the configured thread count above 100 MiB, half of it above
// 50 MiB and a single thread otherwise.
//
// The result is never below 1, so a configured count of 1 (whose half rounds
// down to 0) or a non-positive configured count still yields one thread.
func getTaskThreads(fileSize int64) int {
	threads := 1
	switch {
	case fileSize > 1024*1024*100:
		threads = config.Cfg.Threads
	case fileSize > 1024*1024*50:
		threads = config.Cfg.Threads / 2
	}
	if threads < 1 {
		threads = 1
	}
	return threads
}
// TaskLocalFile is a local file that counts the bytes written through WriteAt
// and invokes a callback each time the count reaches the next threshold.
type TaskLocalFile struct {
	file             *os.File
	size             int64                                 // expected total size, passed to the callback
	done             int64                                 // bytes written so far
	progressCallback func(bytesRead, contentLength int64) // may be nil
	callbackTimes    int64
	nextCallbackAt   int64 // byte count at which the callback fires next
	callbackInterval int64 // amount nextCallbackAt advances after each callback
}

// Read reads from the underlying file.
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
	return t.file.Read(p)
}

// Close closes the underlying file.
func (t *TaskLocalFile) Close() error {
	return t.file.Close()
}

// WriteAt writes p to the underlying file at offset off. After a successful
// write the byte counter grows by the written amount; if it has reached the
// current threshold and a callback is set, the callback is invoked once and
// the threshold advances by one interval.
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
	written, err := t.file.WriteAt(p, off)
	if err != nil {
		return written, err
	}
	t.done += int64(written)
	if t.progressCallback == nil || t.done < t.nextCallbackAt {
		return written, nil
	}
	t.progressCallback(t.done, t.size)
	t.nextCallbackAt += t.callbackInterval
	return written, nil
}

// NewTaskLocalFile creates (or truncates) the file at filePath and wraps it.
// The callback interval is one hundredth of fileSize, or 1 when that is zero.
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
	created, err := os.Create(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}

	interval := fileSize / 100
	if interval == 0 {
		interval = 1
	}

	local := &TaskLocalFile{
		file:             created,
		size:             fileSize,
		progressCallback: progressCallback,
		callbackTimes:    100,
		nextCallbackAt:   interval,
		callbackInterval: interval,
	}
	return local, nil
}
// ProgressStream is an io.Writer wrapper that counts the bytes written and
// invokes a callback each time the count reaches the next threshold.
type ProgressStream struct {
	writer   io.Writer
	size     int64                                 // expected total size, passed to the callback
	done     int64                                 // bytes written so far
	callback func(bytesRead, contentLength int64) // may be nil
	nextAt   int64                                 // byte count at which the callback fires next
	interval int64                                 // amount nextAt advances after each callback
}

// Write forwards p to the wrapped writer. After a successful write the byte
// counter grows by the written amount; if it has reached the current
// threshold and a callback is set, the callback is invoked once and the
// threshold advances by one interval.
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
	if n, err = ps.writer.Write(p); err != nil {
		return n, err
	}
	ps.done += int64(n)
	if ps.callback == nil || ps.done < ps.nextAt {
		return n, nil
	}
	ps.callback(ps.done, ps.size)
	ps.nextAt += ps.interval
	return n, nil
}

// NewProgressStream wraps writer. The callback interval is one hundredth of
// size, or 1 when that is zero.
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
	step := size / 100
	if step == 0 {
		step = 1
	}
	stream := &ProgressStream{
		writer:   writer,
		size:     size,
		callback: callback,
	}
	stream.nextAt, stream.interval = step, step
	return stream
}
// getNodeImages collects the "src" attribute of every "img" element found in
// node and, recursively, in its children. The node is converted to a
// telegraph.NodeElement through a JSON round trip; if either step fails, a
// nil slice is returned for that node.
func getNodeImages(node telegraph.Node) []string {
	var images []string

	raw, err := json.Marshal(node)
	if err != nil {
		return images
	}
	var element telegraph.NodeElement
	if err := json.Unmarshal(raw, &element); err != nil {
		return images
	}

	if element.Tag == "img" {
		if src, ok := element.Attrs["src"]; ok {
			images = append(images, src)
		}
	}
	for i := range element.Children {
		images = append(images, getNodeImages(element.Children[i])...)
	}
	return images
}