diff --git a/client/bot/handlers/add_task.go b/client/bot/handlers/add_task.go index 858436c..a29886a 100644 --- a/client/bot/handlers/add_task.go +++ b/client/bot/handlers/add_task.go @@ -100,7 +100,7 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error { } shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID) case tasktype.TaskTypeYtdlp: - shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID) + shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, data.YtdlpFlags, msgID, userID) default: return fmt.Errorf("unexcept task type: %s", data.TaskType) } diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go index e4edae0..81869e9 100644 --- a/client/bot/handlers/utils/msgelem/storage.go +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -51,7 +51,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) DirectLinks: adddata.DirectLinks, Aria2URIs: adddata.Aria2URIs, - YtdlpURLs: adddata.YtdlpURLs, + YtdlpURLs: adddata.YtdlpURLs, + YtdlpFlags: adddata.YtdlpFlags, } dataid := xid.New().String() err := cache.Set(dataid, data) diff --git a/client/bot/handlers/utils/shortcut/ytdlp.go b/client/bot/handlers/utils/shortcut/ytdlp.go index 9cdefbd..038ece7 100644 --- a/client/bot/handlers/utils/shortcut/ytdlp.go +++ b/client/bot/handlers/utils/shortcut/ytdlp.go @@ -15,7 +15,7 @@ import ( "github.com/krau/SaveAny-Bot/storage" ) -func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error { +func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, flags []string, msgID int, userID int64) error { logger := log.FromContext(ctx) injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) @@ -29,13 +29,14 @@ func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPa return dispatcher.EndGroups } - logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls)) + logger.Infof("Creating yt-dlp task for %d URL(s) with %d flag(s)", len(urls), len(flags)) // Create yt-dlp task task := ytdlp.NewTask( xid.New().String(), injectCtx, urls, + flags, stor, stor.JoinStoragePath(dirPath), ytdlp.NewProgress(msgID, userID), diff --git a/client/bot/handlers/ytdlp.go b/client/bot/handlers/ytdlp.go index 614e8d8..33eb2fd 100644 --- a/client/bot/handlers/ytdlp.go +++ b/client/bot/handlers/ytdlp.go @@ -7,7 +7,6 @@ import ( "github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/ext" "github.com/charmbracelet/log" - "github.com/duke-git/lancet/v2/slice" "github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem" "github.com/krau/SaveAny-Bot/common/i18n" @@ -25,29 +24,53 @@ func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - urls := args[1:] - // Validate and clean URLs - for i, link := range urls { - urls[i] = strings.TrimSpace(link) - u, err := url.Parse(link) - if err != nil || u.Scheme == "" || u.Host == "" { - logger.Warnf("Invalid URL: %s", link) - urls[i] = "" + // Separate URLs and flags from arguments + var urls []string + var flags []string + + for i := 1; i < len(args); i++ { + arg := strings.TrimSpace(args[i]) + if arg == "" { + continue + } + + // Check if it's a flag (starts with - or --) + if strings.HasPrefix(arg, "-") { + flags = append(flags, arg) + // Check if the next argument is a value for this flag (not starting with -) + if i+1 < len(args) && !strings.HasPrefix(strings.TrimSpace(args[i+1]), "-") { + nextArg := strings.TrimSpace(args[i+1]) + // Only treat as flag value if it's not a valid URL + u, err := url.Parse(nextArg) + if err != nil || u.Scheme == "" || u.Host == "" { + flags = append(flags, nextArg) + i++ // Skip the next argument as it's been consumed + continue + } + } + } else { + // Try to parse as URL + u, err := url.Parse(arg) + if err != nil || u.Scheme == "" || u.Host == "" { + logger.Warnf("Invalid URL: %s", arg) + continue + } + urls = append(urls, arg) } } - urls = slice.Compact(urls) if len(urls) == 0 { ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil) return dispatcher.EndGroups } - logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls)) + logger.Debugf("Preparing yt-dlp download for %d URL(s) with %d flag(s)", len(urls), len(flags)) // Build storage selection keyboard markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{ - TaskType: tasktype.TaskTypeYtdlp, - YtdlpURLs: urls, + TaskType: tasktype.TaskTypeYtdlp, + YtdlpURLs: urls, + YtdlpFlags: flags, }) if err != nil { return err diff --git a/common/i18n/locale/en.yaml b/common/i18n/locale/en.yaml index e8d3282..995e3e9 100644 --- a/common/i18n/locale/en.yaml +++ b/common/i18n/locale/en.yaml @@ -289,7 +289,7 @@ bot: error_no_valid_links: "No valid links to download" info_files_select_storage: "Total {{.Count}} files, please select storage" ytdlp: - usage: "Usage: /ytdlp ..." + usage: "Usage: /ytdlp [OPTIONS] [URL2] ...\nExamples:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video" error_no_valid_urls: "No valid URLs" info_urls_select_storage: "Found {{.Count}} links, please select storage" info_downloading: "Downloading via yt-dlp..." diff --git a/common/i18n/locale/zh-Hans.yaml b/common/i18n/locale/zh-Hans.yaml index 3d8e0b2..c4218d6 100644 --- a/common/i18n/locale/zh-Hans.yaml +++ b/common/i18n/locale/zh-Hans.yaml @@ -290,7 +290,7 @@ bot: error_no_valid_links: "没有有效的链接可供下载" info_files_select_storage: "共 {{.Count}} 个文件, 请选择存储位置" ytdlp: - usage: "用法: /ytdlp ..." + usage: "用法: /ytdlp [选项] [URL2] ...\n示例:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video" error_no_valid_urls: "没有有效的 URL" info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置" info_downloading: "正在通过 yt-dlp 下载..." diff --git a/core/tasks/ytdlp/execute.go b/core/tasks/ytdlp/execute.go index 0b7012e..af942d2 100644 --- a/core/tasks/ytdlp/execute.go +++ b/core/tasks/ytdlp/execute.go @@ -80,22 +80,31 @@ func (t *Task) Execute(ctx context.Context) error { func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) { logger := log.FromContext(ctx) - // Configure yt-dlp command + // Configure yt-dlp command with default settings cmd := ytdlp.New(). - FormatSort("res,ext:mp4:m4a"). - RecodeVideo("mp4"). - Output(filepath.Join(tempDir, "%(title)s.%(ext)s")). - RestrictFilenames() + Output(filepath.Join(tempDir, "%(title)s.%(ext)s")) + + // If no custom flags are provided, use default behavior + if len(t.Flags) == 0 { + cmd = cmd. + FormatSort("res,ext:mp4:m4a"). + RecodeVideo("mp4"). + RestrictFilenames() + } if t.Progress != nil { t.Progress.OnProgress(ctx, t, "Downloading...") } - // Execute download with URLs as arguments - logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs)) + // Execute download with URLs and custom flags + logger.Infof("Executing yt-dlp for %d URL(s) with %d custom flag(s)", len(t.URLs), len(t.Flags)) + + // Combine URLs and flags as arguments + // The Run method will pass flags as raw command-line arguments + args := append(t.Flags, t.URLs...) // Run with context for cancellation support - result, err := cmd.Run(ctx, t.URLs...) + result, err := cmd.Run(ctx, args...) if err != nil { // Check if context was canceled if errors.Is(err, context.Canceled) { diff --git a/core/tasks/ytdlp/task.go b/core/tasks/ytdlp/task.go index a9154e9..ef945c0 100644 --- a/core/tasks/ytdlp/task.go +++ b/core/tasks/ytdlp/task.go @@ -15,6 +15,7 @@ type Task struct { ID string ctx context.Context URLs []string + Flags []string Storage storage.Storage StorPath string Progress ProgressTracker @@ -43,6 +44,7 @@ func NewTask( id string, ctx context.Context, urls []string, + flags []string, stor storage.Storage, storPath string, progressTracker ProgressTracker, @@ -51,6 +53,7 @@ func NewTask( ID: id, ctx: ctx, URLs: urls, + Flags: flags, Storage: stor, StorPath: storPath, Progress: progressTracker, diff --git a/pkg/tcbdata/data.go b/pkg/tcbdata/data.go index 74aacad..be6faa8 100644 --- a/pkg/tcbdata/data.go +++ b/pkg/tcbdata/data.go @@ -48,7 +48,8 @@ type Add struct { // aria2 Aria2URIs []string // ytdlp - YtdlpURLs []string + YtdlpURLs []string + YtdlpFlags []string } type SetDefaultStorage struct {