feat: add custom parameter support to /ytdlp command (#185), close #184

* Initial plan

* Implement parameter support for /ytdlp command

Co-authored-by: krau <71133316+krau@users.noreply.github.com>

* Add comprehensive tests for ytdlp parameter parsing

Co-authored-by: krau <71133316+krau@users.noreply.github.com>

* Improve flag parsing logic and clarify argument order

Co-authored-by: krau <71133316+krau@users.noreply.github.com>

* Preserve critical defaults and improve comments

Co-authored-by: krau <71133316+krau@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
Copilot
2026-01-19 13:10:21 +08:00
committed by GitHub
parent 3ce00884a0
commit 3e20dc2c5f
24 changed files with 350 additions and 58 deletions

View File

@@ -100,7 +100,7 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
}
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
case tasktype.TaskTypeYtdlp:
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID)
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, data.YtdlpFlags, msgID, userID)
default:
return fmt.Errorf("unexcept task type: %s", data.TaskType)
}

View File

@@ -84,7 +84,7 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
return nil
}
logger.Debug("Preparing aria2 download", "links", links)
// Initialize aria2 client to check connection
aria2ClientInitOnce.Do(func() {
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)

View File

@@ -114,7 +114,7 @@ func processMediaGroup(ctx *ext.Context, update *ext.Update, groupID int64) {
if err != nil {
logger.Errorf("Failed to build storage selection keyboard: %s", err)
ctx.EditMessage(userId, &tg.MessagesEditMessageRequest{
ID: msg.ID,
ID: msg.ID,
Message: i18n.T(i18nk.BotMsgMediaGroupErrorBuildStorageSelectKeyboardFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -38,7 +38,7 @@ func handleTaskCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextStyledTextArray([]styling.StyledTextOption{
styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)),
styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)),
styling.Code(taskID),
}), nil)
default:

View File

@@ -103,7 +103,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
return err
}
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradingWithVersion, map[string]any{
"Current": config.Version,
}),
@@ -111,7 +111,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
latest, err := ghselfupdate.UpdateSelf(currentV, config.GitRepo)
if err != nil {
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateErrorUpgradeFailed, map[string]any{
"Error": err.Error(),
}),
@@ -119,7 +119,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
return dispatcher.EndGroups
}
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradeSuccess, map[string]any{
"Version": latest.Version.String(),
}),

View File

@@ -112,7 +112,7 @@ func BuildFilenameTemplateData(message *tg.Message) map[string]string {
}(),
MsgRaw: message.GetMessage(),
ChatID: func() string {
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id,
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id,
// 无论它是否是从其他来源转发的
if message.GetPost() {
peer := message.GetPeerID()

View File

@@ -50,8 +50,9 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
DirectLinks: adddata.DirectLinks,
Aria2URIs: adddata.Aria2URIs,
YtdlpURLs: adddata.YtdlpURLs,
Aria2URIs: adddata.Aria2URIs,
YtdlpURLs: adddata.YtdlpURLs,
YtdlpFlags: adddata.YtdlpFlags,
}
dataid := xid.New().String()
err := cache.Set(dataid, data)

View File

@@ -22,7 +22,7 @@ func CreateAndAddParsedTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirP
if err := core.AddTask(injectCtx, task); err != nil {
log.FromContext(ctx).Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -29,7 +29,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage
if err != nil {
logger.Errorf("Failed to get user by chat ID: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{
"Error": err.Error(),
}),
@@ -49,7 +49,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{
"Error": err.Error(),
}),
@@ -69,7 +69,7 @@ startCreateTask:
if err != nil {
logger.Errorf("create task failed: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
@@ -79,7 +79,7 @@ startCreateTask:
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("add task failed: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
@@ -103,7 +103,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to get user by chat ID: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{
"Error": err.Error(),
}),
@@ -142,7 +142,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{
"Error": err.Error(),
}),
@@ -156,10 +156,10 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to create task element: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
@@ -193,7 +193,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to create task element for album file: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
@@ -210,7 +210,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add batch task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
@@ -218,8 +218,8 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
return dispatcher.EndGroups
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{
"Count": len(files),
}),
ReplyMarkup: nil,

View File

@@ -25,7 +25,7 @@ func CreateAndAddtelegraphWithEdit(
pics []string,
stor storage.Storage,
trackMsgID int) error {
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
task := tphtask.NewTask(xid.New().String(),
injectCtx,
@@ -39,7 +39,7 @@ func CreateAndAddtelegraphWithEdit(
if err := core.AddTask(injectCtx, task); err != nil {
log.FromContext(ctx).Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -15,7 +15,7 @@ import (
"github.com/krau/SaveAny-Bot/storage"
)
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error {
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, flags []string, msgID int, userID int64) error {
logger := log.FromContext(ctx)
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
@@ -29,13 +29,14 @@ func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPa
return dispatcher.EndGroups
}
logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls))
logger.Infof("Creating yt-dlp task for %d URL(s) with %d flag(s)", len(urls), len(flags))
// Create yt-dlp task
task := ytdlp.NewTask(
xid.New().String(),
injectCtx,
urls,
flags,
stor,
stor.JoinStoragePath(dirPath),
ytdlp.NewProgress(msgID, userID),

View File

@@ -7,7 +7,6 @@ import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/slice"
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
"github.com/krau/SaveAny-Bot/common/i18n"
@@ -25,29 +24,59 @@ func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
urls := args[1:]
// Validate and clean URLs
for i, link := range urls {
urls[i] = strings.TrimSpace(link)
u, err := url.Parse(link)
if err != nil || u.Scheme == "" || u.Host == "" {
logger.Warnf("Invalid URL: %s", link)
urls[i] = ""
// Separate URLs and flags from arguments
var urls []string
var flags []string
for i := 1; i < len(args); i++ {
arg := strings.TrimSpace(args[i])
if arg == "" {
continue
}
// Check if it's a flag (starts with - or --)
if strings.HasPrefix(arg, "-") {
flags = append(flags, arg)
// Check if the next argument might be a value for this flag
// Don't consume it if it starts with - or looks like a URL with scheme
if i+1 < len(args) {
nextArg := strings.TrimSpace(args[i+1])
if nextArg != "" && !strings.HasPrefix(nextArg, "-") {
// Check if it's clearly a URL (has ://)
// This handles common video URLs (http://, https://)
// For other yt-dlp inputs, users should ensure proper formatting
if strings.Contains(nextArg, "://") {
// It's a URL, don't consume it as a flag value
continue
}
// Otherwise, treat it as a flag value
flags = append(flags, nextArg)
i++ // Skip the next argument as it's been consumed
}
}
} else {
// Try to parse as URL
u, err := url.Parse(arg)
if err != nil || u.Scheme == "" || u.Host == "" {
logger.Warnf("Invalid URL: %s", arg)
continue
}
urls = append(urls, arg)
}
}
urls = slice.Compact(urls)
if len(urls) == 0 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil)
return dispatcher.EndGroups
}
logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls))
logger.Debugf("Preparing yt-dlp download for %d URL(s) with %d flag(s)", len(urls), len(flags))
// Build storage selection keyboard
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
TaskType: tasktype.TaskTypeYtdlp,
YtdlpURLs: urls,
TaskType: tasktype.TaskTypeYtdlp,
YtdlpURLs: urls,
YtdlpFlags: flags,
})
if err != nil {
return err

View File

@@ -0,0 +1,129 @@
package handlers
import (
"net/url"
"strings"
"testing"
)
// TestYtdlpArgumentParsing tests the URL and flag separation logic
func TestYtdlpArgumentParsing(t *testing.T) {
tests := []struct {
name string
input string
expectedURLs []string
expectedFlags []string
}{
{
name: "Single URL without flags",
input: "/ytdlp https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{},
},
{
name: "Multiple URLs without flags",
input: "/ytdlp https://example.com/v1 https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{},
},
{
name: "URL with format flag",
input: "/ytdlp --format best https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--format", "best"},
},
{
name: "URL with extract-audio flag",
input: "/ytdlp --extract-audio --audio-format mp3 https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--extract-audio", "--audio-format", "mp3"},
},
{
name: "Multiple URLs with flags",
input: "/ytdlp --format best https://example.com/v1 https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{"--format", "best"},
},
{
name: "Flags mixed with URLs",
input: "/ytdlp https://example.com/v1 --format best https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{"--format", "best"},
},
{
name: "Short flag",
input: "/ytdlp -f best https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"-f", "best"},
},
{
name: "Boolean flag",
input: "/ytdlp --extract-audio https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--extract-audio"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := strings.Split(tt.input, " ")
// Simulate the parsing logic from handleYtdlpCmd
var urls []string
var flags []string
for i := 1; i < len(args); i++ {
arg := strings.TrimSpace(args[i])
if arg == "" {
continue
}
// Check if it's a flag (starts with - or --)
if strings.HasPrefix(arg, "-") {
flags = append(flags, arg)
// Check if the next argument might be a value for this flag
if i+1 < len(args) {
nextArg := strings.TrimSpace(args[i+1])
if nextArg != "" && !strings.HasPrefix(nextArg, "-") {
// Check if it's clearly a URL (has ://)
if strings.Contains(nextArg, "://") {
// It's a URL, don't consume it as a flag value
continue
}
// Otherwise, treat it as a flag value
flags = append(flags, nextArg)
i++ // Skip the next argument as it's been consumed
}
}
} else {
// Try to parse as URL
u, err := url.Parse(arg)
if err != nil || u.Scheme == "" || u.Host == "" {
continue
}
urls = append(urls, arg)
}
}
// Verify URLs
if len(urls) != len(tt.expectedURLs) {
t.Errorf("Expected %d URLs, got %d", len(tt.expectedURLs), len(urls))
}
for i, expectedURL := range tt.expectedURLs {
if i >= len(urls) || urls[i] != expectedURL {
t.Errorf("Expected URL[%d] to be '%s', got '%s'", i, expectedURL, urls[i])
}
}
// Verify flags
if len(flags) != len(tt.expectedFlags) {
t.Errorf("Expected %d flags, got %d", len(tt.expectedFlags), len(flags))
}
for i, expectedFlag := range tt.expectedFlags {
if i >= len(flags) || flags[i] != expectedFlag {
t.Errorf("Expected flag[%d] to be '%s', got '%s'", i, expectedFlag, flags[i])
}
}
})
}
}