From c21ff7e499861b9e0e5f6e8609ad7284a0c39999 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:10:41 +0800 Subject: [PATCH] feat: add direct links download functionality - Implemented a new task type for handling direct links downloads. - Added command handler for downloading multiple links via /dl command. - Introduced progress tracking for direct link downloads. - Enhanced filename parsing to support various encoding scenarios. - Updated enums to include direct links as a task type. - Refactored existing task structures to accommodate new functionality. - Improved error handling and logging throughout the download process. --- client/bot/handlers/add_task.go | 4 +- client/bot/handlers/dl.go | 49 +++++ client/bot/handlers/register.go | 5 +- client/bot/handlers/save.go | 13 +- client/bot/handlers/utils/msgelem/storage.go | 2 + .../handlers/utils/shortcut/directlinks.go | 30 +++ common/utils/ioutil/writer.go | 6 +- core/tasks/batchtfile/execute.go | 2 +- core/tasks/batchtfile/task.go | 8 +- core/tasks/batchtfile/taskinfo.go | 4 +- core/tasks/directlinks/execute.go | 167 ++++++++++++++ core/tasks/directlinks/progress.go | 196 +++++++++++++++++ core/tasks/directlinks/task.go | 121 +++++++++++ core/tasks/directlinks/util.go | 205 ++++++++++++++++++ core/tasks/parsed/task.go | 6 +- go.sum | 12 - parsers/parser.go | 1 + pkg/enums/tasktype/tasktype.go | 2 +- pkg/enums/tasktype/tasktype_enum.go | 11 +- pkg/parser/parser.go | 2 +- pkg/tcbdata/data.go | 2 + 21 files changed, 804 insertions(+), 44 deletions(-) create mode 100644 client/bot/handlers/dl.go create mode 100644 client/bot/handlers/utils/shortcut/directlinks.go create mode 100644 core/tasks/directlinks/execute.go create mode 100644 core/tasks/directlinks/progress.go create mode 100644 core/tasks/directlinks/task.go create mode 100644 core/tasks/directlinks/util.go diff --git a/client/bot/handlers/add_task.go b/client/bot/handlers/add_task.go index dee3415..114ae08 100644 --- a/client/bot/handlers/add_task.go +++ b/client/bot/handlers/add_task.go @@ -80,8 +80,10 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error { dirPath = path.Join(dirPath, fsutil.NormalizePathname(data.ParsedItem.Title)) } shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID) + case tasktype.TaskTypeDirectlinks: + shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID) default: - log.FromContext(ctx).Errorf("Unsupported task type: %s", data.TaskType) + return fmt.Errorf("unexcept task type: %s", data.TaskType) } return dispatcher.EndGroups } diff --git a/client/bot/handlers/dl.go b/client/bot/handlers/dl.go new file mode 100644 index 0000000..3806c18 --- /dev/null +++ b/client/bot/handlers/dl.go @@ -0,0 +1,49 @@ +package handlers + +import ( + "fmt" + "net/url" + "strings" + + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/slice" + "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/tcbdata" + "github.com/krau/SaveAny-Bot/storage" +) + +func handleDlCmd(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + args := strings.Split(update.EffectiveMessage.Text, " ") + if len(args) < 2 { + ctx.Reply(update, ext.ReplyTextString("用法: /dl <链接1> <链接2> ..."), nil) + return nil + } + links := args[1:] + for i, link := range links { + links[i] = strings.TrimSpace(link) + u, err := url.Parse(link) + if err != nil || u.Scheme == "" || u.Host == "" { + logger.Warn("invaild link", link) + links[i] = "" + } + } + links = slice.Compact(links) + if len(links) == 0 { + ctx.Reply(update, ext.ReplyTextString("没有有效的链接可供下载"), nil) + return nil + } + markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{ + TaskType: tasktype.TaskTypeDirectlinks, + DirectLinks: links, + }) + if err != nil { + return err + } + ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("共 %d 个文件, 请选择存储位置", len(links))), &ext.ReplyOpts{ + Markup: markup, + }) + return nil +} diff --git a/client/bot/handlers/register.go b/client/bot/handlers/register.go index 45cf152..064622d 100644 --- a/client/bot/handlers/register.go +++ b/client/bot/handlers/register.go @@ -26,15 +26,16 @@ var CommandHandlers = []DescCommandHandler{ {"storage", "设置默认存储端", handleStorageCmd}, {"dir", "管理存储文件夹", handleDirCmd}, {"rule", "管理自动存储规则", handleRuleCmd}, + {"save", "保存文件", handleSilentMode(handleSaveCmd, handleSilentSaveReplied)}, + {"dl", "下载给定链接的文件", handleDlCmd}, {"watch", "监听聊天(UserBot)", handleWatchCmd}, {"unwatch", "取消监听聊天(UserBot)", handleUnwatchCmd}, {"lswatch", "列出监听的聊天(UserBot)", handleLswatchCmd}, - {"save", "保存文件", handleSilentMode(handleSaveCmd, handleSilentSaveReplied)}, {"config", "修改配置", handleConfigCmd}, {"fnametmpl", "设置文件命名模板", handleConfigFnameTmpl}, - {"update", "检查更新", handleUpdateCmd}, {"help", "显示帮助", handleHelpCmd}, {"parser", "管理解析器", handleParserCmd}, + {"update", "检查更新", handleUpdateCmd}, } func Register(disp dispatcher.Dispatcher) { diff --git a/client/bot/handlers/save.go b/client/bot/handlers/save.go index 4b8db1f..4f0e2d0 100644 --- a/client/bot/handlers/save.go +++ b/client/bot/handlers/save.go @@ -26,7 +26,7 @@ import ( func handleSaveCmd(ctx *ext.Context, update *ext.Update) error { logger := log.FromContext(ctx) - args := strings.Split(string(update.EffectiveMessage.Text), " ") + args := strings.Split(update.EffectiveMessage.Text, " ") if len(args) >= 3 { return handleBatchSave(ctx, update, args[1:]) } @@ -35,17 +35,6 @@ func handleSaveCmd(ctx *ext.Context, update *ext.Update) error { ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgSaveHelpText)), nil) return dispatcher.EndGroups } - // genFilename := func() string { - // if len(args) > 1 { - // return args[1] - // } - // filename := tgutil.GenFileNameFromMessage(*replyTo.Message) - // return filename - // }() - // option := tfile.WithNameIfEmpty(genFilename) - // if len(args) > 1 { - // option = tfile.WithName(genFilename) - // } userDB, err := database.GetUserByChatID(ctx, update.GetUserChat().GetID()) if err != nil { return err diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go index e66899b..36421c2 100644 --- a/client/bot/handlers/utils/msgelem/storage.go +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -45,6 +45,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) TphDirPath: adddata.TphDirPath, ParsedItem: adddata.ParsedItem, + + DirectLinks: adddata.DirectLinks, } dataid := xid.New().String() err := cache.Set(dataid, data) diff --git a/client/bot/handlers/utils/shortcut/directlinks.go b/client/bot/handlers/utils/shortcut/directlinks.go new file mode 100644 index 0000000..00140b8 --- /dev/null +++ b/client/bot/handlers/utils/shortcut/directlinks.go @@ -0,0 +1,30 @@ +package shortcut + +import ( + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/tasks/directlinks" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +func CreateAndAddDirectTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, links []string, msgID int, userID int64) error { + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + task := directlinks.NewTask(xid.New().String(), injectCtx, links, stor, stor.JoinStoragePath(dirPath), directlinks.NewProgress(msgID, userID)) + if err := core.AddTask(injectCtx, task); err != nil { + log.FromContext(ctx).Errorf("Failed to add task: %s", err) + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + ID: msgID, + Message: "任务添加失败: " + err.Error(), + }) + return dispatcher.EndGroups + } + ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ + Message: "任务已添加", + }) + return dispatcher.EndGroups +} diff --git a/common/utils/ioutil/writer.go b/common/utils/ioutil/writer.go index f34833a..cdb4ec8 100644 --- a/common/utils/ioutil/writer.go +++ b/common/utils/ioutil/writer.go @@ -1,6 +1,8 @@ package ioutil -import "io" +import ( + "io" +) type ProgressWriterAt struct { wrAt io.WriterAt @@ -46,4 +48,4 @@ func NewProgressWriter( wr: wr, onWrite: onWrite, } -} +} \ No newline at end of file diff --git a/core/tasks/batchtfile/execute.go b/core/tasks/batchtfile/execute.go index 1c9427a..403d250 100644 --- a/core/tasks/batchtfile/execute.go +++ b/core/tasks/batchtfile/execute.go @@ -24,7 +24,7 @@ func (t *Task) Execute(ctx context.Context) error { workers := config.C().Workers eg, gctx := errgroup.WithContext(ctx) eg.SetLimit(workers) - for _, elem := range t.Elems { + for _, elem := range t.elems { eg.Go(func() error { t.processingMu.RLock() if t.processing[elem.ID] != nil { diff --git a/core/tasks/batchtfile/task.go b/core/tasks/batchtfile/task.go index e4c1326..cea53b3 100644 --- a/core/tasks/batchtfile/task.go +++ b/core/tasks/batchtfile/task.go @@ -25,8 +25,8 @@ type TaskElement struct { type Task struct { ID string - Ctx context.Context - Elems []TaskElement + ctx context.Context + elems []TaskElement Progress ProgressTracker IgnoreErrors bool // if true, errors during processing will be ignored downloaded atomic.Int64 @@ -78,8 +78,8 @@ func NewBatchTGFileTask( ) *Task { task := &Task{ ID: id, - Ctx: ctx, - Elems: files, + ctx: ctx, + elems: files, Progress: progress, downloaded: atomic.Int64{}, totalSize: func() int64 { diff --git a/core/tasks/batchtfile/taskinfo.go b/core/tasks/batchtfile/taskinfo.go index 396b7f3..8a36483 100644 --- a/core/tasks/batchtfile/taskinfo.go +++ b/core/tasks/batchtfile/taskinfo.go @@ -44,11 +44,11 @@ func (t *Task) Downloaded() int64 { } func (t *Task) Count() int { - return len(t.Elems) + return len(t.elems) } func (t *Task) Processing() []TaskElementInfo { - processing := make([]TaskElementInfo, 0, len(t.Elems)) + processing := make([]TaskElementInfo, 0, len(t.elems)) for _, elem := range t.processing { processing = append(processing, elem) } diff --git a/core/tasks/directlinks/execute.go b/core/tasks/directlinks/execute.go new file mode 100644 index 0000000..87791e2 --- /dev/null +++ b/core/tasks/directlinks/execute.go @@ -0,0 +1,167 @@ +package directlinks + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "path/filepath" + "sync/atomic" + + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/retry" + "github.com/krau/SaveAny-Bot/common/utils/fsutil" + "github.com/krau/SaveAny-Bot/common/utils/ioutil" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "golang.org/x/sync/errgroup" +) + +func (t *Task) Execute(ctx context.Context) error { + logger := log.FromContext(ctx) + logger.Infof("Starting directlinks task %s", t.ID) + if t.Progress != nil { + t.Progress.OnStart(ctx, t) + } + // head all links to get file info + eg, gctx := errgroup.WithContext(ctx) + eg.SetLimit(config.C().Workers) + fetchedTotalBytes := atomic.Int64{} + for _, file := range t.files { + eg.Go(func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, file.URL, nil) + if err != nil { + return fmt.Errorf("failed to create HEAD request for %s: %w", file.URL, err) + } + resp, err := t.client.Do(req) + if err != nil { + return fmt.Errorf("failed to HEAD %s: %w", file.URL, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("HEAD %s returned status %d", file.URL, resp.StatusCode) + } + fetchedTotalBytes.Add(resp.ContentLength) + file.Size = resp.ContentLength + if name := resp.Header.Get("Content-Disposition"); name != "" { + // Set file name + filename := parseFilename(name) + file.Name = filename + } + + return nil + }) + } + err := eg.Wait() + if err != nil { + logger.Errorf("Error during HEAD requests: %v", err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return err + } + t.totalBytes = fetchedTotalBytes.Load() + // start downloading + eg, gctx = errgroup.WithContext(ctx) + eg.SetLimit(config.C().Workers) + for _, file := range t.files { + eg.Go(func() error { + t.processingMu.RLock() + if _, ok := t.processing[file.URL]; ok { + return fmt.Errorf("file %s is already being processed", file.URL) + } + t.processingMu.RUnlock() + t.processingMu.Lock() + t.processing[file.URL] = file + t.processingMu.Unlock() + defer func() { + t.processingMu.Lock() + delete(t.processing, file.URL) + t.processingMu.Unlock() + }() + err := t.processLink(gctx, file) + t.downloaded.Add(1) + if errors.Is(err, context.Canceled) { + logger.Debug("Link processing canceled") + return err + } + if err != nil { + logger.Errorf("Error processing link %s: %v", file.URL, err) + return fmt.Errorf("failed to process link %s: %w", file.URL, err) + } + return nil + }) + } + err = eg.Wait() + if err != nil { + logger.Errorf("Error during directlinks task execution: %v", err) + } else { + logger.Infof("Directlinks task %s completed successfully", t.ID) + } + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } + return err +} + +func (t *Task) processLink(ctx context.Context, file *File) error { + logger := log.FromContext(ctx) + err := retry.Retry(func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, file.URL, nil) + if err != nil { + return fmt.Errorf("failed to create GET request for %s: %w", file.URL, err) + } + resp, err := t.client.Do(req) + if err != nil { + return fmt.Errorf("failed to GET %s: %w", file.URL, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("GET %s returned status %d", file.URL, resp.StatusCode) + } + ctx = context.WithValue(ctx, ctxkey.ContentLength, file.Size) + if t.stream { + return t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name)) + } + cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath, + fmt.Sprintf("direct_%s_%s", t.ID, file.Name))) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + if err := cacheFile.CloseAndRemove(); err != nil { + logger.Errorf("Failed to close and remove cache file: %v", err) + } + }() + wr := ioutil.NewProgressWriter(cacheFile, func(n int) { + t.downloadedBytes.Add(int64(n)) + if t.Progress != nil { + t.Progress.OnProgress(ctx, t) + } + }) + + copyResultCh := make(chan error, 1) + go func() { + _, err := io.Copy(wr, resp.Body) + copyResultCh <- err + }() + select { + case err := <-copyResultCh: + if err != nil { + return fmt.Errorf("failed to copy file %s to cache file: %w", file.URL, err) + } + case <-ctx.Done(): + return ctx.Err() + } + _, err = cacheFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("failed to seek cache file for resource %s: %w", file.URL, err) + } + return t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name)) + }, retry.RetryTimes(uint(config.C().Retry)), retry.Context(ctx)) + if ctx.Err() != nil { + return ctx.Err() + } + return err +} diff --git a/core/tasks/directlinks/progress.go b/core/tasks/directlinks/progress.go new file mode 100644 index 0000000..45aa3b9 --- /dev/null +++ b/core/tasks/directlinks/progress.go @@ -0,0 +1,196 @@ +package directlinks + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/duke-git/lancet/v2/slice" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type TaskInfo interface { + TotalBytes() int64 + TotalFiles() int + TaskID() string + StorageName() string + StoragePath() string + DownloadedBytes() int64 + Processing() []FileInfo +} + +type FileInfo interface { + FileName() string + FileSize() int64 +} + +type ProgressTracker interface { + OnStart(ctx context.Context, info TaskInfo) + OnProgress(ctx context.Context, info TaskInfo) + OnDone(ctx context.Context, info TaskInfo, err error) +} + +type Progress struct { + msgID int + chatID int64 + start time.Time + lastUpdatePercent atomic.Int32 +} + +// OnDone implements ProgressTracker. +func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) { + logger := log.FromContext(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logger.Infof("Parsed task %s was canceled", info.TaskID()) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{ + ID: p.msgID, + Message: fmt.Sprintf("处理已取消: %s", info.TaskID()), + }) + } + } else { + logger.Errorf("Parsed task %s failed: %s", info.TaskID(), err) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{ + ID: p.msgID, + Message: fmt.Sprintf("处理失败: %s", err.Error()), + }) + } + } + return + } + logger.Infof("Parsed task %s completed successfully", info.TaskID()) + + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain("处理完成, 文件数量: "), + styling.Code(fmt.Sprintf("%d", info.TotalFiles())), + styling.Plain("\n保存路径: "), + styling.Code(fmt.Sprintf("[%s]:%s", info.StorageName(), info.StoragePath())), + ); err != nil { + logger.Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, req) + } +} + +// OnProgress implements ProgressTracker. +func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) { + if !shouldUpdateProgress(info.TotalBytes(), info.DownloadedBytes(), int(p.lastUpdatePercent.Load())) { + return + } + percent := int((info.DownloadedBytes() * 100) / info.TotalBytes()) + if p.lastUpdatePercent.Load() == int32(percent) { + return + } + p.lastUpdatePercent.Store(int32(percent)) + log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.DownloadedBytes(), info.TotalBytes()) + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain("正在下载\n总大小: "), + styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalBytes())/(1024*1024), info.TotalFiles())), + styling.Plain("\n正在处理:\n"), + func() styling.StyledTextOption { + var lines []string + for _, elem := range info.Processing() { + lines = append(lines, fmt.Sprintf(" - %s (%.2f MB)", elem.FileName(), float64(elem.FileSize())/(1024*1024))) + } + if len(lines) == 0 { + lines = append(lines, " - 无") + } + return styling.Plain(slice.Join(lines, "\n")) + }(), + styling.Plain("\n平均速度: "), + styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(info.DownloadedBytes(), p.start)/(1024*1024))), + styling.Plain("\n当前进度: "), + styling.Bold(fmt.Sprintf("%.2f%%", float64(info.DownloadedBytes())/float64(info.TotalBytes())*100)), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.chatID, req) + return + } +} + +// OnStart implements ProgressTracker. +func (p *Progress) OnStart(ctx context.Context, info TaskInfo) { + logger := log.FromContext(ctx) + p.start = time.Now() + p.lastUpdatePercent.Store(0) + logger.Infof("Direct links task started: message_id=%d, chat_id=%d", p.msgID, p.chatID) + ext := tgutil.ExtFromContext(ctx) + if ext == nil { + return + } + entityBuilder := entity.Builder{} + var entities []tg.MessageEntityClass + if err := styling.Perform(&entityBuilder, + styling.Plain(fmt.Sprintf("开始下载, 总大小: %.2f MB (%d 个文件)", float64(info.TotalBytes())/(1024*1024), info.TotalFiles()))); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.msgID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }}, + ) + ext.EditMessage(p.chatID, req) +} + +var _ ProgressTracker = (*Progress)(nil) + +func NewProgress(msgID int, userID int64) ProgressTracker { + return &Progress{ + msgID: msgID, + chatID: userID, + } +} diff --git a/core/tasks/directlinks/task.go b/core/tasks/directlinks/task.go new file mode 100644 index 0000000..d5aeacd --- /dev/null +++ b/core/tasks/directlinks/task.go @@ -0,0 +1,121 @@ +package directlinks + +import ( + "context" + "net/http" + "sync" + "sync/atomic" + + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/storage" +) + +type File struct { + Name string + URL string + Size int64 +} + +func (f *File) FileName() string { + return f.Name +} + +func (f *File) FileSize() int64 { + return f.Size +} + +type Task struct { + ID string + ctx context.Context + files []*File + Storage storage.Storage + StorPath string + Progress ProgressTracker + + client *http.Client // [TODO] parallel download + stream bool + totalBytes int64 // total bytes to download + downloadedBytes atomic.Int64 // downloaded bytes + totalFiles int64 // total files to download + downloaded atomic.Int64 // downloaded files count + processing map[string]*File // {"url": File} + processingMu sync.RWMutex + failed map[string]error // [TODO] errors for each file +} + +// DownloadedBytes implements TaskInfo. +func (t *Task) DownloadedBytes() int64 { + return t.downloadedBytes.Load() +} + +// Processing implements TaskInfo. +func (t *Task) Processing() []FileInfo { + t.processingMu.RLock() + defer t.processingMu.RUnlock() + infos := make([]FileInfo, 0, len(t.processing)) + for _, f := range t.processing { + infos = append(infos, f) + } + return infos +} + +// StorageName implements TaskInfo. +func (t *Task) StorageName() string { + return t.Storage.Name() +} + +// StoragePath implements TaskInfo. +func (t *Task) StoragePath() string { + return t.StorPath +} + +// TotalBytes implements TaskInfo. +func (t *Task) TotalBytes() int64 { + return t.totalBytes +} + +// TotalFiles implements TaskInfo. +func (t *Task) TotalFiles() int { + return int(t.totalFiles) +} + +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeDirectlinks +} + +func (t *Task) TaskID() string { + return t.ID +} + +func NewTask( + id string, + ctx context.Context, + links []string, + stor storage.Storage, + storPath string, + progressTracker ProgressTracker, +) *Task { + _, ok := stor.(storage.StorageCannotStream) + stream := config.C().Stream && !ok + files := make([]*File, 0, len(links)) + for _, link := range links { + files = append(files, &File{ + URL: link, + }) + } + return &Task{ + ID: id, + ctx: ctx, + files: files, + Storage: stor, + StorPath: storPath, + Progress: progressTracker, + stream: stream, + client: http.DefaultClient, + processing: make(map[string]*File), + processingMu: sync.RWMutex{}, + failed: make(map[string]error), + totalFiles: int64(len(files)), + } +} diff --git a/core/tasks/directlinks/util.go b/core/tasks/directlinks/util.go new file mode 100644 index 0000000..8a64d76 --- /dev/null +++ b/core/tasks/directlinks/util.go @@ -0,0 +1,205 @@ +package directlinks + +import ( + "mime" + "net/url" + "strings" + "unicode/utf8" + + "golang.org/x/text/encoding/simplifiedchinese" +) + +// parseFilename extracts filename from Content-Disposition header +// It handles multiple encoding scenarios: +// 1. RFC 5987/RFC 2231 format: filename*=UTF-8”%E6%B5%8B%E8%AF%95.zip (preferred, checked first) +// 2. MIME encoded-word: filename="=?UTF-8?B?5rWL6K+VLnppcA==?=" +// 3. URL-encoded: filename="%E6%B5%8B%E8%AF%95.zip" +// 4. Plain ASCII filename +// +// The key fix is checking filename*= first before mime.ParseMediaType, because +// some servers send Content-Disposition headers with invalid characters that cause +// mime.ParseMediaType to fail, but the filename*= parameter is still valid. +func parseFilename(contentDisposition string) string { + // First, try to find filename*= (RFC 5987 format, most reliable for non-ASCII) + if filename := parseFilenameExtended(contentDisposition); filename != "" { + return filename + } + + // Try standard MIME parsing for regular filename= parameter + _, params, err := mime.ParseMediaType(contentDisposition) + if err == nil { + if filename := params["filename"]; filename != "" { + return decodeFilenameParam(filename) + } + } + + // Fallback: manual parsing if mime.ParseMediaType fails + return parseFilenameFallback(contentDisposition) +} + +// parseFilenameExtended parses RFC 5987/RFC 2231 extended parameter format +// Format: filename*=charset'language'value (e.g., UTF-8”%E6%B5%8B%E8%AF%95.zip) +func parseFilenameExtended(cd string) string { + // Look for filename*= (case-insensitive) + lower := strings.ToLower(cd) + idx := strings.Index(lower, "filename*=") + if idx == -1 { + return "" + } + + // Extract the value after filename*= + value := cd[idx+len("filename*="):] + + // Find the end of the value (next ; or end of string) + if endIdx := strings.Index(value, ";"); endIdx != -1 { + value = value[:endIdx] + } + value = strings.TrimSpace(value) + + // Parse charset'language'encoded-value format + // Common format: UTF-8''%E6%B5%8B%E8%AF%95.zip + parts := strings.SplitN(value, "''", 2) + if len(parts) == 2 { + // parts[0] is charset (e.g., "UTF-8") + // parts[1] is percent-encoded value + decoded, err := url.QueryUnescape(parts[1]) + if err == nil { + return decoded + } + } + + // Try with single quote delimiter as well (some servers use this) + parts = strings.SplitN(value, "'", 3) + if len(parts) >= 3 { + decoded, err := url.QueryUnescape(parts[2]) + if err == nil { + return decoded + } + } + + return "" +} + +// TryUrlQueryUnescape tries to unescape a URL-encoded string. +// +// If unescaping fails, it returns the original string. +func tryUrlQueryUnescape(s string) string { + if decoded, err := url.QueryUnescape(s); err == nil { + return decoded + } + return s +} + +// decodeFilenameParam decodes a filename parameter value +// Handles MIME encoded-word, URL encoding, and GBK encoding fallback +func decodeFilenameParam(filename string) string { + // Check if the filename is MIME encoded-word (e.g., =?UTF-8?B?...?=) + if strings.HasPrefix(filename, "=?") { + decoder := new(mime.WordDecoder) + // Some servers use "UTF8" instead of "UTF-8", create a normalized copy + normalizedFilename := strings.Replace(filename, "UTF8", "UTF-8", 1) + if decoded, err := decoder.Decode(normalizedFilename); err == nil { + return decoded + } + } + + // Try URL decoding + decoded := tryUrlQueryUnescape(filename) + + // Check if the result is valid UTF-8. If not, try GBK decoding. + // This handles the case where Chinese Windows servers send GBK-encoded filenames + // which appear as garbled characters (e.g., "下载地址.zip" -> "���ص�ַ.zip") + if !utf8.ValidString(decoded) { + if gbkDecoded := tryDecodeGBK(decoded); gbkDecoded != "" { + return gbkDecoded + } + } + + return decoded +} + +// gbkDecoder is a reusable GBK decoder for better performance +var gbkDecoder = simplifiedchinese.GBK.NewDecoder() + +// tryDecodeGBK attempts to decode a string as GBK/GB2312/GB18030 encoding +// Returns empty string if decoding fails or result is not valid UTF-8 +func tryDecodeGBK(s string) string { + // GBK uses 1-2 bytes per character. Single-byte chars are 0x00-0x7F (ASCII compatible). + // Double-byte chars have first byte 0x81-0xFE and second byte 0x40-0xFE. + // Skip if string is empty or all ASCII (valid UTF-8) + if len(s) == 0 { + return "" + } + + // Create a fresh decoder since the transform state may be corrupted + decoder := gbkDecoder + decoded, err := decoder.Bytes([]byte(s)) + if err != nil { + return "" + } + result := string(decoded) + if utf8.ValidString(result) { + return result + } + return "" +} + +// parseFilenameFallback manually parses filename= when mime.ParseMediaType fails +func parseFilenameFallback(cd string) string { + // Look for filename= (case-insensitive) + lower := strings.ToLower(cd) + idx := strings.Index(lower, "filename=") + if idx == -1 { + return "" + } + + // Skip "filename=" prefix + value := cd[idx+len("filename="):] + + // Find the end of the value + if endIdx := strings.Index(value, ";"); endIdx != -1 { + value = value[:endIdx] + } + value = strings.TrimSpace(value) + + // Remove quotes if present + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + + return decodeFilenameParam(value) +} + +var progressUpdatesLevels = []struct { + size int64 // 文件大小阈值 + stepPercent int // 每多少 % 更新一次 +}{ + {10 << 20, 100}, + {50 << 20, 50}, + {200 << 20, 20}, + {500 << 20, 10}, +} + +func shouldUpdateProgress(total, downloaded int64, lastUpdatePercent int) bool { + if total <= 0 || downloaded <= 0 { + return false + } + + percent := int((downloaded * 100) / total) + if percent <= lastUpdatePercent { + return false + } + + step := progressUpdatesLevels[len(progressUpdatesLevels)-1].stepPercent + for _, lvl := range progressUpdatesLevels { + if total < lvl.size { + step = lvl.stepPercent + break + } + } + + return percent >= lastUpdatePercent+step +} diff --git a/core/tasks/parsed/task.go b/core/tasks/parsed/task.go index 281aa99..fae7106 100644 --- a/core/tasks/parsed/task.go +++ b/core/tasks/parsed/task.go @@ -19,9 +19,9 @@ type Task struct { Stor storage.Storage StorPath string item *parser.Item - httpClient *http.Client - progress ProgressTracker - stream bool + httpClient *http.Client // [TODO] btorrent support? + progress ProgressTracker + stream bool totalResources int64 downloaded atomic.Int64 // downloaded resources count diff --git a/go.sum b/go.sum index 51012d2..8e58281 100644 --- a/go.sum +++ b/go.sum @@ -106,8 +106,6 @@ github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GM github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= -github.com/go-faster/jx v1.1.0 h1:ZsW3wD+snOdmTDy9eIVgQdjUpXRRV4rqW8NS3t+20bg= -github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb3skg= github.com/go-faster/jx v1.2.0 h1:T2YHJPrFaYu21fJtUxC9GzmluKu8rVIFDwwGBKTDseI= github.com/go-faster/jx v1.2.0/go.mod h1:UWLOVDmMG597a5tBFPLIWJdUxz5/2emOpfsj9Neg0PE= github.com/go-faster/xor v0.3.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ= @@ -155,8 +153,6 @@ github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk= github.com/gotd/ige v0.2.2/go.mod h1:tuCRb+Y5Y3eNTo3ypIfNpQ4MFjrnONiL2jN2AKZXmb0= github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ= github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ= -github.com/gotd/td v0.132.0 h1:Iqm3S2b+8kDgA9237IDXRxj7sryUpvy+4Cr50/0tpx4= -github.com/gotd/td v0.132.0/go.mod h1:4CDGYS+rDtOqotRheGaF9MS5g6jaUewvSXqBNJnx8SQ= github.com/gotd/td v0.136.0 h1:f7vx/1rlvP59L5EKR820XpMRO2k267wW8/F0rAWbepc= github.com/gotd/td v0.136.0/go.mod h1:mStcqs/9FXhNhWnPTguptSwqkQbRIwXLw3SCSpzPJxM= github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf h1:WfD7VjIE6z8dIvMsI4/s+1qr5EL+zoIGev1BQj1eoJ8= @@ -169,8 +165,6 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/johannesboyne/gofakes3 v0.0.0-20250916175020-ebf3e50324d3 h1:2713fQZ560HxoNVgfJH41GKzjMjIG+DW4hH6nYXfXW8= github.com/johannesboyne/gofakes3 v0.0.0-20250916175020-ebf3e50324d3/go.mod h1:S4S9jGBVlLri0OeqrSSbCGG5vsI6he06UJyuz1WT1EE= -github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= -github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -290,8 +284,6 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -305,8 +297,6 @@ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2 golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= -golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -356,8 +346,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= -golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/parsers/parser.go b/parsers/parser.go index 2f13f52..2d44823 100644 --- a/parsers/parser.go +++ b/parsers/parser.go @@ -49,6 +49,7 @@ func ParseWithContext(ctx context.Context, url string) (*parser.Item, error) { } } +// CanHandle checks if any registered parser can handle the given URL and returns the parser if found. func CanHandle(url string) (bool, parser.Parser) { for _, pser := range parsers.Get() { if pser.CanHandle(url) { diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go index 5621674..454a08a 100644 --- a/pkg/enums/tasktype/tasktype.go +++ b/pkg/enums/tasktype/tasktype.go @@ -1,5 +1,5 @@ package tasktype //go:generate go-enum --values --names --flag --nocase -// ENUM(tgfiles,tphpics,parseditem) +// ENUM(tgfiles,tphpics,parseditem,directlinks) type TaskType string diff --git a/pkg/enums/tasktype/tasktype_enum.go b/pkg/enums/tasktype/tasktype_enum.go index 69799f4..df0f5ae 100644 --- a/pkg/enums/tasktype/tasktype_enum.go +++ b/pkg/enums/tasktype/tasktype_enum.go @@ -18,6 +18,8 @@ const ( TaskTypeTphpics TaskType = "tphpics" // TaskTypeParseditem is a TaskType of type parseditem. TaskTypeParseditem TaskType = "parseditem" + // TaskTypeDirectlinks is a TaskType of type directlinks. + TaskTypeDirectlinks TaskType = "directlinks" ) var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", ")) @@ -26,6 +28,7 @@ var _TaskTypeNames = []string{ string(TaskTypeTgfiles), string(TaskTypeTphpics), string(TaskTypeParseditem), + string(TaskTypeDirectlinks), } // TaskTypeNames returns a list of possible string values of TaskType. @@ -41,6 +44,7 @@ func TaskTypeValues() []TaskType { TaskTypeTgfiles, TaskTypeTphpics, TaskTypeParseditem, + TaskTypeDirectlinks, } } @@ -57,9 +61,10 @@ func (x TaskType) IsValid() bool { } var _TaskTypeValue = map[string]TaskType{ - "tgfiles": TaskTypeTgfiles, - "tphpics": TaskTypeTphpics, - "parseditem": TaskTypeParseditem, + "tgfiles": TaskTypeTgfiles, + "tphpics": TaskTypeTphpics, + "parseditem": TaskTypeParseditem, + "directlinks": TaskTypeDirectlinks, } // ParseTaskType attempts to convert a string to a TaskType. diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index c96fc6a..b537df5 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -55,7 +55,7 @@ func (r *Resource) ID() string { h.Write([]byte(r.Filename)) h.Write([]byte(r.MimeType)) h.Write([]byte(r.Extension)) - h.Write([]byte(fmt.Sprintf("%d", r.Size))) + fmt.Fprintf(h, "%d", r.Size) for k, v := range r.Hash { h.Write([]byte(k)) diff --git a/pkg/tcbdata/data.go b/pkg/tcbdata/data.go index 980823d..84691b8 100644 --- a/pkg/tcbdata/data.go +++ b/pkg/tcbdata/data.go @@ -43,6 +43,8 @@ type Add struct { TphDirPath string // unescaped telegraph.Page.Path // parseditem ParsedItem *parser.Item + // directlinks + DirectLinks []string } type SetDefaultStorage struct {