From 3e20dc2c5f66728e4dc68114346326865e21dee6 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 19 Jan 2026 13:10:21 +0800 Subject: [PATCH] feat: add custom parameter support to /ytdlp command (#185), close #184 * Initial plan * Implement parameter support for /ytdlp command Co-authored-by: krau <71133316+krau@users.noreply.github.com> * Add comprehensive tests for ytdlp parameter parsing Co-authored-by: krau <71133316+krau@users.noreply.github.com> * Improve flag parsing logic and clarify argument order Co-authored-by: krau <71133316+krau@users.noreply.github.com> * Preserve critical defaults and improve comments Co-authored-by: krau <71133316+krau@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: krau <71133316+krau@users.noreply.github.com> --- client/bot/handlers/add_task.go | 2 +- client/bot/handlers/dl.go | 2 +- client/bot/handlers/media_group.go | 2 +- client/bot/handlers/tasks.go | 2 +- client/bot/handlers/update.go | 6 +- client/bot/handlers/utils/mediautil/media.go | 2 +- client/bot/handlers/utils/msgelem/storage.go | 5 +- client/bot/handlers/utils/shortcut/parsed.go | 2 +- client/bot/handlers/utils/shortcut/tftask.go | 28 ++-- client/bot/handlers/utils/shortcut/tphtask.go | 4 +- client/bot/handlers/utils/shortcut/ytdlp.go | 5 +- client/bot/handlers/ytdlp.go | 55 ++++++-- client/bot/handlers/ytdlp_test.go | 129 ++++++++++++++++++ common/i18n/locale/en.yaml | 2 +- common/i18n/locale/zh-Hans.yaml | 2 +- common/utils/ioutil/writer.go | 2 +- core/tasks/ytdlp/execute.go | 28 ++-- core/tasks/ytdlp/task.go | 3 + core/tasks/ytdlp/task_test.go | 114 ++++++++++++++++ database/user.go | 2 +- pkg/enums/ctxkey/context_key.go | 3 +- pkg/enums/tasktype/tasktype.go | 3 +- pkg/tcbdata/data.go | 3 +- pkg/tfile/opts.go | 2 +- 24 files changed, 350 insertions(+), 58 deletions(-) create mode 100644 client/bot/handlers/ytdlp_test.go create mode 100644 core/tasks/ytdlp/task_test.go diff --git a/client/bot/handlers/add_task.go b/client/bot/handlers/add_task.go index 858436c..a29886a 100644 --- a/client/bot/handlers/add_task.go +++ b/client/bot/handlers/add_task.go @@ -100,7 +100,7 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error { } shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID) case tasktype.TaskTypeYtdlp: - shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID) + shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, data.YtdlpFlags, msgID, userID) default: return fmt.Errorf("unexcept task type: %s", data.TaskType) } diff --git a/client/bot/handlers/dl.go b/client/bot/handlers/dl.go index 518886d..d231ec1 100644 --- a/client/bot/handlers/dl.go +++ b/client/bot/handlers/dl.go @@ -84,7 +84,7 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { return nil } logger.Debug("Preparing aria2 download", "links", links) - + // Initialize aria2 client to check connection aria2ClientInitOnce.Do(func() { aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret) diff --git a/client/bot/handlers/media_group.go b/client/bot/handlers/media_group.go index 2b11738..bd3956c 100644 --- a/client/bot/handlers/media_group.go +++ b/client/bot/handlers/media_group.go @@ -114,7 +114,7 @@ func processMediaGroup(ctx *ext.Context, update *ext.Update, groupID int64) { if err != nil { logger.Errorf("Failed to build storage selection keyboard: %s", err) ctx.EditMessage(userId, &tg.MessagesEditMessageRequest{ - ID: msg.ID, + ID: msg.ID, Message: i18n.T(i18nk.BotMsgMediaGroupErrorBuildStorageSelectKeyboardFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/tasks.go b/client/bot/handlers/tasks.go index c5b7ca1..1d6a2d2 100644 --- a/client/bot/handlers/tasks.go +++ b/client/bot/handlers/tasks.go @@ -38,7 +38,7 @@ func handleTaskCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } ctx.Reply(update, ext.ReplyTextStyledTextArray([]styling.StyledTextOption{ - styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)), + styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)), styling.Code(taskID), }), nil) default: diff --git a/client/bot/handlers/update.go b/client/bot/handlers/update.go index fae434a..933a9b1 100644 --- a/client/bot/handlers/update.go +++ b/client/bot/handlers/update.go @@ -103,7 +103,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { return err } ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradingWithVersion, map[string]any{ "Current": config.Version, }), @@ -111,7 +111,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { latest, err := ghselfupdate.UpdateSelf(currentV, config.GitRepo) if err != nil { ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateErrorUpgradeFailed, map[string]any{ "Error": err.Error(), }), @@ -119,7 +119,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { return dispatcher.EndGroups } ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradeSuccess, map[string]any{ "Version": latest.Version.String(), }), diff --git a/client/bot/handlers/utils/mediautil/media.go b/client/bot/handlers/utils/mediautil/media.go index 8917c52..88eaeb3 100644 --- a/client/bot/handlers/utils/mediautil/media.go +++ b/client/bot/handlers/utils/mediautil/media.go @@ -112,7 +112,7 @@ func BuildFilenameTemplateData(message *tg.Message) map[string]string { }(), MsgRaw: message.GetMessage(), ChatID: func() string { - // 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id, + // 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id, // 无论它是否是从其他来源转发的 if message.GetPost() { peer := message.GetPeerID() diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go index e4edae0..71828ba 100644 --- a/client/bot/handlers/utils/msgelem/storage.go +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -50,8 +50,9 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) DirectLinks: adddata.DirectLinks, - Aria2URIs: adddata.Aria2URIs, - YtdlpURLs: adddata.YtdlpURLs, + Aria2URIs: adddata.Aria2URIs, + YtdlpURLs: adddata.YtdlpURLs, + YtdlpFlags: adddata.YtdlpFlags, } dataid := xid.New().String() err := cache.Set(dataid, data) diff --git a/client/bot/handlers/utils/shortcut/parsed.go b/client/bot/handlers/utils/shortcut/parsed.go index dc2fe8d..1cdceae 100644 --- a/client/bot/handlers/utils/shortcut/parsed.go +++ b/client/bot/handlers/utils/shortcut/parsed.go @@ -22,7 +22,7 @@ func CreateAndAddParsedTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirP if err := core.AddTask(injectCtx, task); err != nil { log.FromContext(ctx).Errorf("Failed to add task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: msgID, + ID: msgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/utils/shortcut/tftask.go b/client/bot/handlers/utils/shortcut/tftask.go index c63e8d6..72fa7a0 100644 --- a/client/bot/handlers/utils/shortcut/tftask.go +++ b/client/bot/handlers/utils/shortcut/tftask.go @@ -29,7 +29,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage if err != nil { logger.Errorf("Failed to get user by chat ID: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{ "Error": err.Error(), }), @@ -49,7 +49,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage if err != nil { logger.Errorf("Failed to get storage by user ID and name: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{ "Error": err.Error(), }), @@ -69,7 +69,7 @@ startCreateTask: if err != nil { logger.Errorf("create task failed: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ "Error": err.Error(), }), @@ -79,7 +79,7 @@ startCreateTask: if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("add task failed: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), @@ -103,7 +103,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to get user by chat ID: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{ "Error": err.Error(), }), @@ -142,7 +142,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to get storage by user ID and name: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{ "Error": err.Error(), }), @@ -156,10 +156,10 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to create task element: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, - Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ - "Error": err.Error(), - }), + ID: trackMsgID, + Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ + "Error": err.Error(), + }), }) return dispatcher.EndGroups } @@ -193,7 +193,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to create task element for album file: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ "Error": err.Error(), }), @@ -210,7 +210,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add batch task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), @@ -218,8 +218,8 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st return dispatcher.EndGroups } ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, - Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{ + ID: trackMsgID, + Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{ "Count": len(files), }), ReplyMarkup: nil, diff --git a/client/bot/handlers/utils/shortcut/tphtask.go b/client/bot/handlers/utils/shortcut/tphtask.go index 6245750..02e3fec 100644 --- a/client/bot/handlers/utils/shortcut/tphtask.go +++ b/client/bot/handlers/utils/shortcut/tphtask.go @@ -25,7 +25,7 @@ func CreateAndAddtelegraphWithEdit( pics []string, stor storage.Storage, trackMsgID int) error { - + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) task := tphtask.NewTask(xid.New().String(), injectCtx, @@ -39,7 +39,7 @@ func CreateAndAddtelegraphWithEdit( if err := core.AddTask(injectCtx, task); err != nil { log.FromContext(ctx).Errorf("Failed to add task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/utils/shortcut/ytdlp.go b/client/bot/handlers/utils/shortcut/ytdlp.go index 9cdefbd..038ece7 100644 --- a/client/bot/handlers/utils/shortcut/ytdlp.go +++ b/client/bot/handlers/utils/shortcut/ytdlp.go @@ -15,7 +15,7 @@ import ( "github.com/krau/SaveAny-Bot/storage" ) -func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error { +func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, flags []string, msgID int, userID int64) error { logger := log.FromContext(ctx) injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) @@ -29,13 +29,14 @@ func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPa return dispatcher.EndGroups } - logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls)) + logger.Infof("Creating yt-dlp task for %d URL(s) with %d flag(s)", len(urls), len(flags)) // Create yt-dlp task task := ytdlp.NewTask( xid.New().String(), injectCtx, urls, + flags, stor, stor.JoinStoragePath(dirPath), ytdlp.NewProgress(msgID, userID), diff --git a/client/bot/handlers/ytdlp.go b/client/bot/handlers/ytdlp.go index 614e8d8..4140afa 100644 --- a/client/bot/handlers/ytdlp.go +++ b/client/bot/handlers/ytdlp.go @@ -7,7 +7,6 @@ import ( "github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/ext" "github.com/charmbracelet/log" - "github.com/duke-git/lancet/v2/slice" "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" "github.com/krau/SaveAny-Bot/common/i18n" @@ -25,29 +24,59 @@ func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - urls := args[1:] - // Validate and clean URLs - for i, link := range urls { - urls[i] = strings.TrimSpace(link) - u, err := url.Parse(link) - if err != nil || u.Scheme == "" || u.Host == "" { - logger.Warnf("Invalid URL: %s", link) - urls[i] = "" + // Separate URLs and flags from arguments + var urls []string + var flags []string + + for i := 1; i < len(args); i++ { + arg := strings.TrimSpace(args[i]) + if arg == "" { + continue + } + + // Check if it's a flag (starts with - or --) + if strings.HasPrefix(arg, "-") { + flags = append(flags, arg) + // Check if the next argument might be a value for this flag + // Don't consume it if it starts with - or looks like a URL with scheme + if i+1 < len(args) { + nextArg := strings.TrimSpace(args[i+1]) + if nextArg != "" && !strings.HasPrefix(nextArg, "-") { + // Check if it's clearly a URL (has ://) + // This handles common video URLs (http://, https://) + // For other yt-dlp inputs, users should ensure proper formatting + if strings.Contains(nextArg, "://") { + // It's a URL, don't consume it as a flag value + continue + } + // Otherwise, treat it as a flag value + flags = append(flags, nextArg) + i++ // Skip the next argument as it's been consumed + } + } + } else { + // Try to parse as URL + u, err := url.Parse(arg) + if err != nil || u.Scheme == "" || u.Host == "" { + logger.Warnf("Invalid URL: %s", arg) + continue + } + urls = append(urls, arg) } } - urls = slice.Compact(urls) if len(urls) == 0 { ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil) return dispatcher.EndGroups } - logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls)) + logger.Debugf("Preparing yt-dlp download for %d URL(s) with %d flag(s)", len(urls), len(flags)) // Build storage selection keyboard markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{ - TaskType: tasktype.TaskTypeYtdlp, - YtdlpURLs: urls, + TaskType: tasktype.TaskTypeYtdlp, + YtdlpURLs: urls, + YtdlpFlags: flags, }) if err != nil { return err diff --git a/client/bot/handlers/ytdlp_test.go b/client/bot/handlers/ytdlp_test.go new file mode 100644 index 0000000..39fb491 --- /dev/null +++ b/client/bot/handlers/ytdlp_test.go @@ -0,0 +1,129 @@ +package handlers + +import ( + "net/url" + "strings" + "testing" +) + +// TestYtdlpArgumentParsing tests the URL and flag separation logic +func TestYtdlpArgumentParsing(t *testing.T) { + tests := []struct { + name string + input string + expectedURLs []string + expectedFlags []string + }{ + { + name: "Single URL without flags", + input: "/ytdlp https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{}, + }, + { + name: "Multiple URLs without flags", + input: "/ytdlp https://example.com/v1 https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{}, + }, + { + name: "URL with format flag", + input: "/ytdlp --format best https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "URL with extract-audio flag", + input: "/ytdlp --extract-audio --audio-format mp3 https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--extract-audio", "--audio-format", "mp3"}, + }, + { + name: "Multiple URLs with flags", + input: "/ytdlp --format best https://example.com/v1 https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "Flags mixed with URLs", + input: "/ytdlp https://example.com/v1 --format best https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "Short flag", + input: "/ytdlp -f best https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"-f", "best"}, + }, + { + name: "Boolean flag", + input: "/ytdlp --extract-audio https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--extract-audio"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := strings.Split(tt.input, " ") + + // Simulate the parsing logic from handleYtdlpCmd + var urls []string + var flags []string + + for i := 1; i < len(args); i++ { + arg := strings.TrimSpace(args[i]) + if arg == "" { + continue + } + + // Check if it's a flag (starts with - or --) + if strings.HasPrefix(arg, "-") { + flags = append(flags, arg) + // Check if the next argument might be a value for this flag + if i+1 < len(args) { + nextArg := strings.TrimSpace(args[i+1]) + if nextArg != "" && !strings.HasPrefix(nextArg, "-") { + // Check if it's clearly a URL (has ://) + if strings.Contains(nextArg, "://") { + // It's a URL, don't consume it as a flag value + continue + } + // Otherwise, treat it as a flag value + flags = append(flags, nextArg) + i++ // Skip the next argument as it's been consumed + } + } + } else { + // Try to parse as URL + u, err := url.Parse(arg) + if err != nil || u.Scheme == "" || u.Host == "" { + continue + } + urls = append(urls, arg) + } + } + + // Verify URLs + if len(urls) != len(tt.expectedURLs) { + t.Errorf("Expected %d URLs, got %d", len(tt.expectedURLs), len(urls)) + } + for i, expectedURL := range tt.expectedURLs { + if i >= len(urls) || urls[i] != expectedURL { + t.Errorf("Expected URL[%d] to be '%s', got '%s'", i, expectedURL, urls[i]) + } + } + + // Verify flags + if len(flags) != len(tt.expectedFlags) { + t.Errorf("Expected %d flags, got %d", len(tt.expectedFlags), len(flags)) + } + for i, expectedFlag := range tt.expectedFlags { + if i >= len(flags) || flags[i] != expectedFlag { + t.Errorf("Expected flag[%d] to be '%s', got '%s'", i, expectedFlag, flags[i]) + } + } + }) + } +} diff --git a/common/i18n/locale/en.yaml b/common/i18n/locale/en.yaml index e8d3282..995e3e9 100644 --- a/common/i18n/locale/en.yaml +++ b/common/i18n/locale/en.yaml @@ -289,7 +289,7 @@ bot: error_no_valid_links: "No valid links to download" info_files_select_storage: "Total {{.Count}} files, please select storage" ytdlp: - usage: "Usage: /ytdlp ..." + usage: "Usage: /ytdlp [OPTIONS] [URL2] ...\nExamples:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video" error_no_valid_urls: "No valid URLs" info_urls_select_storage: "Found {{.Count}} links, please select storage" info_downloading: "Downloading via yt-dlp..." diff --git a/common/i18n/locale/zh-Hans.yaml b/common/i18n/locale/zh-Hans.yaml index 3d8e0b2..c4218d6 100644 --- a/common/i18n/locale/zh-Hans.yaml +++ b/common/i18n/locale/zh-Hans.yaml @@ -290,7 +290,7 @@ bot: error_no_valid_links: "没有有效的链接可供下载" info_files_select_storage: "共 {{.Count}} 个文件, 请选择存储位置" ytdlp: - usage: "用法: /ytdlp ..." + usage: "用法: /ytdlp [选项] [URL2] ...\n示例:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video" error_no_valid_urls: "没有有效的 URL" info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置" info_downloading: "正在通过 yt-dlp 下载..." diff --git a/common/utils/ioutil/writer.go b/common/utils/ioutil/writer.go index cdb4ec8..bd4eaa0 100644 --- a/common/utils/ioutil/writer.go +++ b/common/utils/ioutil/writer.go @@ -48,4 +48,4 @@ func NewProgressWriter( wr: wr, onWrite: onWrite, } -} \ No newline at end of file +} diff --git a/core/tasks/ytdlp/execute.go b/core/tasks/ytdlp/execute.go index 0b7012e..20bb362 100644 --- a/core/tasks/ytdlp/execute.go +++ b/core/tasks/ytdlp/execute.go @@ -80,22 +80,34 @@ func (t *Task) Execute(ctx context.Context) error { func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) { logger := log.FromContext(ctx) - // Configure yt-dlp command + // Configure yt-dlp command with essential settings + // Always set output path to ensure files go to temp directory cmd := ytdlp.New(). - FormatSort("res,ext:mp4:m4a"). - RecodeVideo("mp4"). - Output(filepath.Join(tempDir, "%(title)s.%(ext)s")). - RestrictFilenames() + Output(filepath.Join(tempDir, "%(title)s.%(ext)s")) + + // If no custom flags are provided, use default behavior + if len(t.Flags) == 0 { + cmd = cmd. + FormatSort("res,ext:mp4:m4a"). + RecodeVideo("mp4"). + RestrictFilenames() + } + // Note: If custom flags are provided, users have full control over format/quality + // The output path is always set above to ensure downloads go to the correct directory if t.Progress != nil { t.Progress.OnProgress(ctx, t, "Downloading...") } - // Execute download with URLs as arguments - logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs)) + // Execute download with URLs and custom flags + logger.Infof("Executing yt-dlp for %d URL(s) with %d custom flag(s)", len(t.URLs), len(t.Flags)) + + // Combine flags and URLs as arguments (flags first, then URLs) + // yt-dlp accepts: yt-dlp [OPTIONS] URL [URL...] + args := append(t.Flags, t.URLs...) // Run with context for cancellation support - result, err := cmd.Run(ctx, t.URLs...) + result, err := cmd.Run(ctx, args...) if err != nil { // Check if context was canceled if errors.Is(err, context.Canceled) { diff --git a/core/tasks/ytdlp/task.go b/core/tasks/ytdlp/task.go index a9154e9..ef945c0 100644 --- a/core/tasks/ytdlp/task.go +++ b/core/tasks/ytdlp/task.go @@ -15,6 +15,7 @@ type Task struct { ID string ctx context.Context URLs []string + Flags []string Storage storage.Storage StorPath string Progress ProgressTracker @@ -43,6 +44,7 @@ func NewTask( id string, ctx context.Context, urls []string, + flags []string, stor storage.Storage, storPath string, progressTracker ProgressTracker, @@ -51,6 +53,7 @@ func NewTask( ID: id, ctx: ctx, URLs: urls, + Flags: flags, Storage: stor, StorPath: storPath, Progress: progressTracker, diff --git a/core/tasks/ytdlp/task_test.go b/core/tasks/ytdlp/task_test.go new file mode 100644 index 0000000..dab045e --- /dev/null +++ b/core/tasks/ytdlp/task_test.go @@ -0,0 +1,114 @@ +package ytdlp + +import ( + "context" + "io" + "testing" + + storcfg "github.com/krau/SaveAny-Bot/config/storage" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" +) + +// MockStorage is a simple mock for testing +type MockStorage struct{} + +func (m *MockStorage) Init(ctx context.Context, cfg storcfg.StorageConfig) error { return nil } +func (m *MockStorage) Type() storenum.StorageType { return "mock" } +func (m *MockStorage) Name() string { return "test-storage" } +func (m *MockStorage) JoinStoragePath(p string) string { return "test-path" } +func (m *MockStorage) Save(ctx context.Context, reader io.Reader, path string) error { return nil } +func (m *MockStorage) Exists(ctx context.Context, path string) bool { return false } + +func TestNewTask(t *testing.T) { + ctx := context.Background() + urls := []string{"https://example.com/video"} + flags := []string{"--format", "best"} + stor := &MockStorage{} + storPath := "test-path" + + task := NewTask("test-id", ctx, urls, flags, stor, storPath, nil) + + if task == nil { + t.Fatal("NewTask returned nil") + } + + if task.ID != "test-id" { + t.Errorf("Expected task ID 'test-id', got '%s'", task.ID) + } + + if len(task.URLs) != 1 || task.URLs[0] != "https://example.com/video" { + t.Errorf("Expected URLs to contain 'https://example.com/video', got %v", task.URLs) + } + + if len(task.Flags) != 2 || task.Flags[0] != "--format" || task.Flags[1] != "best" { + t.Errorf("Expected flags to contain '--format' and 'best', got %v", task.Flags) + } + + if task.Storage.Name() != "test-storage" { + t.Errorf("Expected storage name 'test-storage', got '%s'", task.Storage.Name()) + } +} + +func TestNewTaskWithoutFlags(t *testing.T) { + ctx := context.Background() + urls := []string{"https://example.com/video1", "https://example.com/video2"} + var flags []string // No flags + stor := &MockStorage{} + storPath := "test-path" + + task := NewTask("test-id-2", ctx, urls, flags, stor, storPath, nil) + + if task == nil { + t.Fatal("NewTask returned nil") + } + + if len(task.URLs) != 2 { + t.Errorf("Expected 2 URLs, got %d", len(task.URLs)) + } + + if len(task.Flags) != 0 { + t.Errorf("Expected 0 flags, got %d", len(task.Flags)) + } +} + +func TestTaskTitle(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + + // Test with single URL + task1 := NewTask("id1", ctx, []string{"https://example.com/video"}, nil, stor, "path", nil) + title1 := task1.Title() + if title1 == "" { + t.Error("Task title should not be empty") + } + + // Test with multiple URLs + task2 := NewTask("id2", ctx, []string{"https://example.com/v1", "https://example.com/v2"}, nil, stor, "path", nil) + title2 := task2.Title() + if title2 == "" { + t.Error("Task title should not be empty") + } +} + +func TestTaskType(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + task := NewTask("id", ctx, []string{"https://example.com"}, nil, stor, "path", nil) + + taskType := task.Type() + if taskType.String() != "ytdlp" { + t.Errorf("Expected task type 'ytdlp', got '%s'", taskType.String()) + } +} + +func TestTaskID(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + expectedID := "test-task-id-123" + + task := NewTask(expectedID, ctx, []string{"https://example.com"}, nil, stor, "path", nil) + + if task.TaskID() != expectedID { + t.Errorf("Expected task ID '%s', got '%s'", expectedID, task.TaskID()) + } +} diff --git a/database/user.go b/database/user.go index a27cd2c..7f631b5 100644 --- a/database/user.go +++ b/database/user.go @@ -49,4 +49,4 @@ func GetUserByID(ctx context.Context, id uint) (*User, error) { Preload(clause.Associations). Where("id = ?", id).First(&user).Error return &user, err -} \ No newline at end of file +} diff --git a/pkg/enums/ctxkey/context_key.go b/pkg/enums/ctxkey/context_key.go index a56ead8..9696efc 100644 --- a/pkg/enums/ctxkey/context_key.go +++ b/pkg/enums/ctxkey/context_key.go @@ -1,5 +1,6 @@ package ctxkey -//go:generate go-enum --values --names --flag --nocase --noprefix // ENUM(content-length) +// +//go:generate go-enum --values --names --flag --nocase --noprefix type ContextKey string diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go index 005ef92..f1248c6 100644 --- a/pkg/enums/tasktype/tasktype.go +++ b/pkg/enums/tasktype/tasktype.go @@ -1,5 +1,6 @@ package tasktype -//go:generate go-enum --values --names --flag --nocase // ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp) +// +//go:generate go-enum --values --names --flag --nocase type TaskType string diff --git a/pkg/tcbdata/data.go b/pkg/tcbdata/data.go index 74aacad..be6faa8 100644 --- a/pkg/tcbdata/data.go +++ b/pkg/tcbdata/data.go @@ -48,7 +48,8 @@ type Add struct { // aria2 Aria2URIs []string // ytdlp - YtdlpURLs []string + YtdlpURLs []string + YtdlpFlags []string } type SetDefaultStorage struct { diff --git a/pkg/tfile/opts.go b/pkg/tfile/opts.go index b6641c4..3478243 100644 --- a/pkg/tfile/opts.go +++ b/pkg/tfile/opts.go @@ -36,4 +36,4 @@ func WithSizeIfZero(size int64) TGFileOption { f.size = size } } -} \ No newline at end of file +}