diff --git a/client/bot/handlers/dl.go b/client/bot/handlers/dl.go index 518886d..d231ec1 100644 --- a/client/bot/handlers/dl.go +++ b/client/bot/handlers/dl.go @@ -84,7 +84,7 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error { return nil } logger.Debug("Preparing aria2 download", "links", links) - + // Initialize aria2 client to check connection aria2ClientInitOnce.Do(func() { aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret) diff --git a/client/bot/handlers/media_group.go b/client/bot/handlers/media_group.go index 2b11738..bd3956c 100644 --- a/client/bot/handlers/media_group.go +++ b/client/bot/handlers/media_group.go @@ -114,7 +114,7 @@ func processMediaGroup(ctx *ext.Context, update *ext.Update, groupID int64) { if err != nil { logger.Errorf("Failed to build storage selection keyboard: %s", err) ctx.EditMessage(userId, &tg.MessagesEditMessageRequest{ - ID: msg.ID, + ID: msg.ID, Message: i18n.T(i18nk.BotMsgMediaGroupErrorBuildStorageSelectKeyboardFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/tasks.go b/client/bot/handlers/tasks.go index c5b7ca1..1d6a2d2 100644 --- a/client/bot/handlers/tasks.go +++ b/client/bot/handlers/tasks.go @@ -38,7 +38,7 @@ func handleTaskCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } ctx.Reply(update, ext.ReplyTextStyledTextArray([]styling.StyledTextOption{ - styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)), + styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)), styling.Code(taskID), }), nil) default: diff --git a/client/bot/handlers/update.go b/client/bot/handlers/update.go index fae434a..933a9b1 100644 --- a/client/bot/handlers/update.go +++ b/client/bot/handlers/update.go @@ -103,7 +103,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { return err } ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradingWithVersion, map[string]any{ "Current": config.Version, }), @@ -111,7 +111,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { latest, err := ghselfupdate.UpdateSelf(currentV, config.GitRepo) if err != nil { ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateErrorUpgradeFailed, map[string]any{ "Error": err.Error(), }), @@ -119,7 +119,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error { return dispatcher.EndGroups } ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{ - ID: u.CallbackQuery.GetMsgID(), + ID: u.CallbackQuery.GetMsgID(), Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradeSuccess, map[string]any{ "Version": latest.Version.String(), }), diff --git a/client/bot/handlers/utils/mediautil/media.go b/client/bot/handlers/utils/mediautil/media.go index 8917c52..88eaeb3 100644 --- a/client/bot/handlers/utils/mediautil/media.go +++ b/client/bot/handlers/utils/mediautil/media.go @@ -112,7 +112,7 @@ func BuildFilenameTemplateData(message *tg.Message) map[string]string { }(), MsgRaw: message.GetMessage(), ChatID: func() string { - // 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id, + // 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id, // 无论它是否是从其他来源转发的 if message.GetPost() { peer := message.GetPeerID() diff --git a/client/bot/handlers/utils/msgelem/storage.go b/client/bot/handlers/utils/msgelem/storage.go index 81869e9..71828ba 100644 --- a/client/bot/handlers/utils/msgelem/storage.go +++ b/client/bot/handlers/utils/msgelem/storage.go @@ -50,7 +50,7 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add) DirectLinks: adddata.DirectLinks, - Aria2URIs: adddata.Aria2URIs, + Aria2URIs: adddata.Aria2URIs, YtdlpURLs: adddata.YtdlpURLs, YtdlpFlags: adddata.YtdlpFlags, } diff --git a/client/bot/handlers/utils/shortcut/parsed.go b/client/bot/handlers/utils/shortcut/parsed.go index dc2fe8d..1cdceae 100644 --- a/client/bot/handlers/utils/shortcut/parsed.go +++ b/client/bot/handlers/utils/shortcut/parsed.go @@ -22,7 +22,7 @@ func CreateAndAddParsedTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirP if err := core.AddTask(injectCtx, task); err != nil { log.FromContext(ctx).Errorf("Failed to add task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: msgID, + ID: msgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/utils/shortcut/tftask.go b/client/bot/handlers/utils/shortcut/tftask.go index c63e8d6..72fa7a0 100644 --- a/client/bot/handlers/utils/shortcut/tftask.go +++ b/client/bot/handlers/utils/shortcut/tftask.go @@ -29,7 +29,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage if err != nil { logger.Errorf("Failed to get user by chat ID: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{ "Error": err.Error(), }), @@ -49,7 +49,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage if err != nil { logger.Errorf("Failed to get storage by user ID and name: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{ "Error": err.Error(), }), @@ -69,7 +69,7 @@ startCreateTask: if err != nil { logger.Errorf("create task failed: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ "Error": err.Error(), }), @@ -79,7 +79,7 @@ startCreateTask: if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("add task failed: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), @@ -103,7 +103,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to get user by chat ID: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{ "Error": err.Error(), }), @@ -142,7 +142,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to get storage by user ID and name: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{ "Error": err.Error(), }), @@ -156,10 +156,10 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to create task element: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, - Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ - "Error": err.Error(), - }), + ID: trackMsgID, + Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ + "Error": err.Error(), + }), }) return dispatcher.EndGroups } @@ -193,7 +193,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err != nil { logger.Errorf("Failed to create task element for album file: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{ "Error": err.Error(), }), @@ -210,7 +210,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st if err := core.AddTask(injectCtx, task); err != nil { logger.Errorf("Failed to add batch task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), @@ -218,8 +218,8 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st return dispatcher.EndGroups } ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, - Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{ + ID: trackMsgID, + Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{ "Count": len(files), }), ReplyMarkup: nil, diff --git a/client/bot/handlers/utils/shortcut/tphtask.go b/client/bot/handlers/utils/shortcut/tphtask.go index 6245750..02e3fec 100644 --- a/client/bot/handlers/utils/shortcut/tphtask.go +++ b/client/bot/handlers/utils/shortcut/tphtask.go @@ -25,7 +25,7 @@ func CreateAndAddtelegraphWithEdit( pics []string, stor storage.Storage, trackMsgID int) error { - + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) task := tphtask.NewTask(xid.New().String(), injectCtx, @@ -39,7 +39,7 @@ func CreateAndAddtelegraphWithEdit( if err := core.AddTask(injectCtx, task); err != nil { log.FromContext(ctx).Errorf("Failed to add task: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ - ID: trackMsgID, + ID: trackMsgID, Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{ "Error": err.Error(), }), diff --git a/client/bot/handlers/ytdlp.go b/client/bot/handlers/ytdlp.go index 33eb2fd..1c584a1 100644 --- a/client/bot/handlers/ytdlp.go +++ b/client/bot/handlers/ytdlp.go @@ -27,13 +27,13 @@ func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error { // Separate URLs and flags from arguments var urls []string var flags []string - + for i := 1; i < len(args); i++ { arg := strings.TrimSpace(args[i]) if arg == "" { continue } - + // Check if it's a flag (starts with - or --) if strings.HasPrefix(arg, "-") { flags = append(flags, arg) diff --git a/client/bot/handlers/ytdlp_test.go b/client/bot/handlers/ytdlp_test.go new file mode 100644 index 0000000..f3b5ee6 --- /dev/null +++ b/client/bot/handlers/ytdlp_test.go @@ -0,0 +1,126 @@ +package handlers + +import ( + "net/url" + "strings" + "testing" +) + +// TestYtdlpArgumentParsing tests the URL and flag separation logic +func TestYtdlpArgumentParsing(t *testing.T) { + tests := []struct { + name string + input string + expectedURLs []string + expectedFlags []string + }{ + { + name: "Single URL without flags", + input: "/ytdlp https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{}, + }, + { + name: "Multiple URLs without flags", + input: "/ytdlp https://example.com/v1 https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{}, + }, + { + name: "URL with format flag", + input: "/ytdlp --format best https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "URL with extract-audio flag", + input: "/ytdlp --extract-audio --audio-format mp3 https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--extract-audio", "--audio-format", "mp3"}, + }, + { + name: "Multiple URLs with flags", + input: "/ytdlp --format best https://example.com/v1 https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "Flags mixed with URLs", + input: "/ytdlp https://example.com/v1 --format best https://example.com/v2", + expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"}, + expectedFlags: []string{"--format", "best"}, + }, + { + name: "Short flag", + input: "/ytdlp -f best https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"-f", "best"}, + }, + { + name: "Boolean flag", + input: "/ytdlp --extract-audio https://example.com/video", + expectedURLs: []string{"https://example.com/video"}, + expectedFlags: []string{"--extract-audio"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := strings.Split(tt.input, " ") + + // Simulate the parsing logic from handleYtdlpCmd + var urls []string + var flags []string + + for i := 1; i < len(args); i++ { + arg := strings.TrimSpace(args[i]) + if arg == "" { + continue + } + + // Check if it's a flag (starts with - or --) + if strings.HasPrefix(arg, "-") { + flags = append(flags, arg) + // Check if the next argument is a value for this flag (not starting with -) + if i+1 < len(args) && !strings.HasPrefix(strings.TrimSpace(args[i+1]), "-") { + nextArg := strings.TrimSpace(args[i+1]) + // Only treat as flag value if it's not a valid URL + u, err := url.Parse(nextArg) + if err != nil || u.Scheme == "" || u.Host == "" { + flags = append(flags, nextArg) + i++ // Skip the next argument as it's been consumed + continue + } + } + } else { + // Try to parse as URL + u, err := url.Parse(arg) + if err != nil || u.Scheme == "" || u.Host == "" { + continue + } + urls = append(urls, arg) + } + } + + // Verify URLs + if len(urls) != len(tt.expectedURLs) { + t.Errorf("Expected %d URLs, got %d", len(tt.expectedURLs), len(urls)) + } + for i, expectedURL := range tt.expectedURLs { + if i >= len(urls) || urls[i] != expectedURL { + t.Errorf("Expected URL[%d] to be '%s', got '%s'", i, expectedURL, urls[i]) + } + } + + // Verify flags + if len(flags) != len(tt.expectedFlags) { + t.Errorf("Expected %d flags, got %d", len(tt.expectedFlags), len(flags)) + } + for i, expectedFlag := range tt.expectedFlags { + if i >= len(flags) || flags[i] != expectedFlag { + t.Errorf("Expected flag[%d] to be '%s', got '%s'", i, expectedFlag, flags[i]) + } + } + }) + } +} diff --git a/common/utils/ioutil/writer.go b/common/utils/ioutil/writer.go index cdb4ec8..bd4eaa0 100644 --- a/common/utils/ioutil/writer.go +++ b/common/utils/ioutil/writer.go @@ -48,4 +48,4 @@ func NewProgressWriter( wr: wr, onWrite: onWrite, } -} \ No newline at end of file +} diff --git a/core/tasks/ytdlp/task_test.go b/core/tasks/ytdlp/task_test.go new file mode 100644 index 0000000..dab045e --- /dev/null +++ b/core/tasks/ytdlp/task_test.go @@ -0,0 +1,114 @@ +package ytdlp + +import ( + "context" + "io" + "testing" + + storcfg "github.com/krau/SaveAny-Bot/config/storage" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" +) + +// MockStorage is a simple mock for testing +type MockStorage struct{} + +func (m *MockStorage) Init(ctx context.Context, cfg storcfg.StorageConfig) error { return nil } +func (m *MockStorage) Type() storenum.StorageType { return "mock" } +func (m *MockStorage) Name() string { return "test-storage" } +func (m *MockStorage) JoinStoragePath(p string) string { return "test-path" } +func (m *MockStorage) Save(ctx context.Context, reader io.Reader, path string) error { return nil } +func (m *MockStorage) Exists(ctx context.Context, path string) bool { return false } + +func TestNewTask(t *testing.T) { + ctx := context.Background() + urls := []string{"https://example.com/video"} + flags := []string{"--format", "best"} + stor := &MockStorage{} + storPath := "test-path" + + task := NewTask("test-id", ctx, urls, flags, stor, storPath, nil) + + if task == nil { + t.Fatal("NewTask returned nil") + } + + if task.ID != "test-id" { + t.Errorf("Expected task ID 'test-id', got '%s'", task.ID) + } + + if len(task.URLs) != 1 || task.URLs[0] != "https://example.com/video" { + t.Errorf("Expected URLs to contain 'https://example.com/video', got %v", task.URLs) + } + + if len(task.Flags) != 2 || task.Flags[0] != "--format" || task.Flags[1] != "best" { + t.Errorf("Expected flags to contain '--format' and 'best', got %v", task.Flags) + } + + if task.Storage.Name() != "test-storage" { + t.Errorf("Expected storage name 'test-storage', got '%s'", task.Storage.Name()) + } +} + +func TestNewTaskWithoutFlags(t *testing.T) { + ctx := context.Background() + urls := []string{"https://example.com/video1", "https://example.com/video2"} + var flags []string // No flags + stor := &MockStorage{} + storPath := "test-path" + + task := NewTask("test-id-2", ctx, urls, flags, stor, storPath, nil) + + if task == nil { + t.Fatal("NewTask returned nil") + } + + if len(task.URLs) != 2 { + t.Errorf("Expected 2 URLs, got %d", len(task.URLs)) + } + + if len(task.Flags) != 0 { + t.Errorf("Expected 0 flags, got %d", len(task.Flags)) + } +} + +func TestTaskTitle(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + + // Test with single URL + task1 := NewTask("id1", ctx, []string{"https://example.com/video"}, nil, stor, "path", nil) + title1 := task1.Title() + if title1 == "" { + t.Error("Task title should not be empty") + } + + // Test with multiple URLs + task2 := NewTask("id2", ctx, []string{"https://example.com/v1", "https://example.com/v2"}, nil, stor, "path", nil) + title2 := task2.Title() + if title2 == "" { + t.Error("Task title should not be empty") + } +} + +func TestTaskType(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + task := NewTask("id", ctx, []string{"https://example.com"}, nil, stor, "path", nil) + + taskType := task.Type() + if taskType.String() != "ytdlp" { + t.Errorf("Expected task type 'ytdlp', got '%s'", taskType.String()) + } +} + +func TestTaskID(t *testing.T) { + ctx := context.Background() + stor := &MockStorage{} + expectedID := "test-task-id-123" + + task := NewTask(expectedID, ctx, []string{"https://example.com"}, nil, stor, "path", nil) + + if task.TaskID() != expectedID { + t.Errorf("Expected task ID '%s', got '%s'", expectedID, task.TaskID()) + } +} diff --git a/database/user.go b/database/user.go index a27cd2c..7f631b5 100644 --- a/database/user.go +++ b/database/user.go @@ -49,4 +49,4 @@ func GetUserByID(ctx context.Context, id uint) (*User, error) { Preload(clause.Associations). Where("id = ?", id).First(&user).Error return &user, err -} \ No newline at end of file +} diff --git a/pkg/enums/ctxkey/context_key.go b/pkg/enums/ctxkey/context_key.go index a56ead8..9696efc 100644 --- a/pkg/enums/ctxkey/context_key.go +++ b/pkg/enums/ctxkey/context_key.go @@ -1,5 +1,6 @@ package ctxkey -//go:generate go-enum --values --names --flag --nocase --noprefix // ENUM(content-length) +// +//go:generate go-enum --values --names --flag --nocase --noprefix type ContextKey string diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go index 005ef92..f1248c6 100644 --- a/pkg/enums/tasktype/tasktype.go +++ b/pkg/enums/tasktype/tasktype.go @@ -1,5 +1,6 @@ package tasktype -//go:generate go-enum --values --names --flag --nocase // ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp) +// +//go:generate go-enum --values --names --flag --nocase type TaskType string diff --git a/pkg/tfile/opts.go b/pkg/tfile/opts.go index b6641c4..3478243 100644 --- a/pkg/tfile/opts.go +++ b/pkg/tfile/opts.go @@ -36,4 +36,4 @@ func WithSizeIfZero(size int64) TGFileOption { f.size = size } } -} \ No newline at end of file +}