From 166c27c70f8d096540047306f106d85515fdc66b Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 12 Apr 2025 14:27:13 +0800 Subject: [PATCH] feat: automatic file organization based on rules, close #28 --- bot/bot.go | 1 + bot/handle_dir.go | 96 ++++++++++++++++++------------ bot/handle_rule.go | 141 +++++++++++++++++++++++++++++++++++++++++++++ bot/handlers.go | 1 + core/download.go | 11 ++-- core/rule.go | 103 +++++++++++++++++++++++++++++++++ dao/dir.go | 4 ++ dao/rule.go | 22 +++++++ types/types.go | 9 +++ 9 files changed, 345 insertions(+), 43 deletions(-) create mode 100644 bot/handle_rule.go create mode 100644 core/rule.go create mode 100644 dao/rule.go diff --git a/bot/bot.go b/bot/bot.go index 80a6741..97d0170 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -78,6 +78,7 @@ func Init() { {Command: "storage", Description: "设置默认存储端"}, {Command: "save", Description: "保存所回复的文件"}, {Command: "dir", Description: "管理存储文件夹"}, + {Command: "rule", Description: "管理规则"}, }, }) resultChan <- struct { diff --git a/bot/handle_dir.go b/bot/handle_dir.go index 5b25e14..6b7b803 100644 --- a/bot/handle_dir.go +++ b/bot/handle_dir.go @@ -1,6 +1,8 @@ package bot import ( + "fmt" + "strconv" "strings" "github.com/celestix/gotgproto/dispatcher" @@ -11,51 +13,71 @@ import ( "github.com/krau/SaveAny-Bot/storage" ) -func dirCmd(ctx *ext.Context, update *ext.Update) error { - args := strings.Split(strings.TrimPrefix(update.EffectiveMessage.Text, "/dir "), " ") - if len(args) < 3 { - dirs, err := dao.GetUserDirsByChatID(update.GetUserChat().GetID()) - if err != nil { - common.Log.Errorf("获取用户路径失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextStyledTextArray( - []styling.StyledTextOption{ - styling.Bold("使用方法: /dir <操作> <存储名> <路径>"), - styling.Plain("\n\n可用操作:\n"), - styling.Code("add"), - styling.Plain(" - 添加路径\n"), - styling.Code("del"), - styling.Plain(" - 删除路径\n"), - styling.Plain("\n示例:\n"), - styling.Code("/dir add local1 path/to/dir"), - styling.Plain("\n\n当前已添加的路径:\n"), - styling.Blockquote(func() string { - var sb strings.Builder - for _, dir := range dirs { - sb.WriteString(dir.StorageName) - sb.WriteString(" - ") - sb.WriteString(dir.Path) - sb.WriteString("\n") - } - return sb.String() - }(), true), - }, - ), nil) +func sendDirHelp(ctx *ext.Context, update *ext.Update, userChatID int64) error { + dirs, err := dao.GetUserDirsByChatID(userChatID) + if err != nil { + common.Log.Errorf("获取用户路径失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户路径失败"), nil) return dispatcher.EndGroups } + ctx.Reply(update, ext.ReplyTextStyledTextArray( + []styling.StyledTextOption{ + styling.Bold("使用方法: /dir <操作> <参数...>"), + styling.Plain("\n\n可用操作:\n"), + styling.Code("add"), + styling.Plain(" <存储名> <路径> - 添加路径\n"), + styling.Code("del"), + styling.Plain(" <路径ID> - 删除路径\n"), + styling.Plain("\n添加路径示例:\n"), + styling.Code("/dir add local1 path/to/dir"), + styling.Plain("\n\n删除路径示例:\n"), + styling.Code("/dir del 3"), + styling.Plain("\n\n当前已添加的路径:\n"), + styling.Blockquote(func() string { + var sb strings.Builder + for _, dir := range dirs { + sb.WriteString(fmt.Sprintf("%d: ", dir.ID)) + sb.WriteString(dir.StorageName) + sb.WriteString(" - ") + sb.WriteString(dir.Path) + sb.WriteString("\n") + } + return sb.String() + }(), true), + }, + ), nil) + return dispatcher.EndGroups +} + +func dirCmd(ctx *ext.Context, update *ext.Update) error { + args := strings.Split(update.EffectiveMessage.Text, " ") + if len(args) < 2 { + return sendDirHelp(ctx, update, update.GetUserChat().GetID()) + } user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) if err != nil { common.Log.Errorf("获取用户失败: %s", err) ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) return dispatcher.EndGroups } - switch args[0] { + switch args[1] { case "add": - return addDir(ctx, update, user, args[1], args[2]) + // /dir add local1 path/to/dir + if len(args) < 4 { + return sendDirHelp(ctx, update, update.GetUserChat().GetID()) + } + return addDir(ctx, update, user, args[2], args[3]) case "del": - return delDir(ctx, update, user, args[1], args[2]) + // /dir del 3 + if len(args) < 3 { + return sendDirHelp(ctx, update, update.GetUserChat().GetID()) + } + dirID, err := strconv.Atoi(args[2]) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("路径ID无效"), nil) + return dispatcher.EndGroups + } + return delDir(ctx, update, user, dirID) default: ctx.Reply(update, ext.ReplyTextString("未知操作"), nil) return dispatcher.EndGroups @@ -77,8 +99,8 @@ func addDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, p return dispatcher.EndGroups } -func delDir(ctx *ext.Context, update *ext.Update, user *dao.User, storageName, path string) error { - if err := dao.DeleteDirForUser(user.ID, storageName, path); err != nil { +func delDir(ctx *ext.Context, update *ext.Update, user *dao.User, dirID int) error { + if err := dao.DeleteDirByID(uint(dirID)); err != nil { common.Log.Errorf("删除路径失败: %s", err) ctx.Reply(update, ext.ReplyTextString("删除路径失败"), nil) return dispatcher.EndGroups diff --git a/bot/handle_rule.go b/bot/handle_rule.go new file mode 100644 index 0000000..d2b8100 --- /dev/null +++ b/bot/handle_rule.go @@ -0,0 +1,141 @@ +package bot + +import ( + "fmt" + "strconv" + "strings" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/duke-git/lancet/v2/slice" + "github.com/gotd/td/telegram/message/styling" + "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/dao" + "github.com/krau/SaveAny-Bot/types" +) + +func sendRuleHelp(ctx *ext.Context, update *ext.Update, userChatID int64) error { + user, err := dao.GetUserByChatID(userChatID) + if err != nil { + common.Log.Errorf("获取用户规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户规则失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextStyledTextArray( + []styling.StyledTextOption{ + styling.Bold("使用方法: /rule <操作> <参数...>"), + styling.Bold(fmt.Sprintf("\n当前已%s规则模式", map[bool]string{true: "启用", false: "禁用"}[user.ApplyRule])), + styling.Plain("\n\n可用操作:\n"), + styling.Code("switch"), + styling.Plain(" - 开关规则模式\n"), + styling.Code("add"), + styling.Plain(" <类型> <数据> <存储名> <路径> - 添加规则\n"), + styling.Code("del"), + styling.Plain(" <规则ID> - 删除规则\n"), + styling.Plain("\n当前已添加的规则:\n"), + styling.Blockquote(func() string { + var sb strings.Builder + for _, rule := range user.Rules { + ruleText := fmt.Sprintf("%s %s %s %s", rule.Type, rule.Data, rule.StorageName, rule.DirPath) + sb.WriteString(fmt.Sprintf("%d: %s\n", rule.ID, ruleText)) + } + return sb.String() + }(), true), + }, + ), nil) + return dispatcher.EndGroups +} + +func ruleCmd(ctx *ext.Context, update *ext.Update) error { + args := strings.Split(update.EffectiveMessage.Text, " ") + if len(args) < 2 { + return sendRuleHelp(ctx, update, update.GetUserChat().GetID()) + } + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) + if err != nil { + common.Log.Errorf("获取用户失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + switch args[1] { + case "switch": + // /rule switch + return switchApplyRule(ctx, update, user) + case "add": + // /rule add + if len(args) < 6 { + return sendRuleHelp(ctx, update, user.ChatID) + } + return addRule(ctx, update, user, args) + case "del": + // /rule del + if len(args) < 3 { + return sendRuleHelp(ctx, update, user.ChatID) + } + ruleID := args[2] + id, err := strconv.Atoi(ruleID) + if err != nil { + ctx.Reply(update, ext.ReplyTextString("无效的规则ID"), nil) + return dispatcher.EndGroups + } + if err := dao.DeleteRule(uint(id)); err != nil { + common.Log.Errorf("删除规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("删除规则失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("删除规则成功"), nil) + return dispatcher.EndGroups + default: + return sendRuleHelp(ctx, update, user.ChatID) + } +} + +func switchApplyRule(ctx *ext.Context, update *ext.Update, user *dao.User) error { + applyRule := !user.ApplyRule + if err := dao.UpdateUserApplyRule(user.ChatID, applyRule); err != nil { + common.Log.Errorf("更新用户失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("更新用户失败"), nil) + return dispatcher.EndGroups + } + if applyRule { + ctx.Reply(update, ext.ReplyTextString("已启用规则模式"), nil) + } else { + ctx.Reply(update, ext.ReplyTextString("已禁用规则模式"), nil) + } + return dispatcher.EndGroups +} + +func addRule(ctx *ext.Context, update *ext.Update, user *dao.User, args []string) error { + // /rule add + ruleType := args[2] + ruleData := args[3] + storageName := args[4] + dirPath := args[5] + + if !slice.Contain(types.RuleTypes, types.RuleType(ruleType)) { + var ruleTypesStylingArray []styling.StyledTextOption + ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Bold("无效的规则类型, 可用类型:\n")) + for i, ruleType := range types.RuleTypes { + ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Code(string(ruleType))) + if i != len(types.RuleTypes)-1 { + ruleTypesStylingArray = append(ruleTypesStylingArray, styling.Plain(", ")) + } + } + ctx.Reply(update, ext.ReplyTextStyledTextArray(ruleTypesStylingArray), nil) + return dispatcher.EndGroups + } + rule := &dao.Rule{ + Type: ruleType, + Data: ruleData, + StorageName: storageName, + DirPath: dirPath, + UserID: user.ID, + } + if err := dao.CreateRule(rule); err != nil { + common.Log.Errorf("添加规则失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("添加规则失败"), nil) + return dispatcher.EndGroups + } + ctx.Reply(update, ext.ReplyTextString("添加规则成功"), nil) + return dispatcher.EndGroups +} diff --git a/bot/handlers.go b/bot/handlers.go index 410d363..8d87701 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -15,6 +15,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd)) dispatcher.AddHandler(handlers.NewCommand("save", saveCmd)) dispatcher.AddHandler(handlers.NewCommand("dir", dirCmd)) + dispatcher.AddHandler(handlers.NewCommand("rule", ruleCmd)) linkRegexFilter, err := filters.Message.Regex(linkRegexString) if err != nil { common.Log.Panicf("创建正则表达式过滤器失败: %s", err) diff --git a/core/download.go b/core/download.go index 61b6c11..d92db33 100644 --- a/core/download.go +++ b/core/download.go @@ -31,15 +31,14 @@ func processPendingTask(task *types.Task) error { task.File.FileName = fmt.Sprintf("%d_%d_%s", task.FileChatID, task.FileMessageID, task.File.Hash()) } - if task.StoragePath == "" { - task.StoragePath = task.FileName() - } - - taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) + taskStorage, storagePath, err := getStorageAndPathForTask(task) if err != nil { return err } - task.StoragePath = taskStorage.JoinStoragePath(*task) + if taskStorage == nil { + return fmt.Errorf("not found storage: %s", task.StorageName) + } + task.StoragePath = storagePath ctx, ok := task.Ctx.(*ext.Context) if !ok { diff --git a/core/rule.go b/core/rule.go new file mode 100644 index 0000000..9193e55 --- /dev/null +++ b/core/rule.go @@ -0,0 +1,103 @@ +package core + +import ( + "fmt" + "path" + "regexp" + + "github.com/celestix/gotgproto/ext" + "github.com/krau/SaveAny-Bot/bot" + "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/dao" + "github.com/krau/SaveAny-Bot/storage" + "github.com/krau/SaveAny-Bot/types" +) + +func getStorageAndPathForTask(task *types.Task) (storage.Storage, string, error) { + user, err := dao.GetUserByChatID(task.UserID) + if err != nil { + return nil, "", fmt.Errorf("failed to get user by chat ID: %w", err) + } + if task.StoragePath == "" { + task.StoragePath = task.FileName() + } + taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) + if err != nil { + return nil, "", err + } + storagePath := taskStorage.JoinStoragePath(*task) + if !user.ApplyRule || user.Rules == nil { + return taskStorage, storagePath, nil + } + var ruleTaskStorage storage.Storage + var ruleStoragePath string + for _, rule := range user.Rules { + matchStorage, matchStoragePath := applyRule(&rule, *task) + if matchStorage != nil && matchStoragePath != "" { + ruleTaskStorage = matchStorage + ruleStoragePath = matchStoragePath + } + } + if ruleStoragePath == "" || ruleTaskStorage == nil { + return taskStorage, storagePath, nil + } + common.Log.Debugf("Rule matched: %s, %s", ruleTaskStorage.Name(), ruleStoragePath) + return ruleTaskStorage, ruleStoragePath, nil +} + +func applyRule(rule *dao.Rule, task types.Task) (storage.Storage, string) { + var DirPath, StorageName string + switch rule.Type { + case string(types.RuleTypeFileNameRegex): + ruleRegex, err := regexp.Compile(rule.Data) + if err != nil { + common.Log.Errorf("failed to compile regex: %s", err) + return nil, "" + } + if !ruleRegex.MatchString(task.FileName()) { + return nil, "" + } + DirPath = rule.DirPath + StorageName = rule.StorageName + case string(types.RuleTypeMessageRegex): + ruleRegex, err := regexp.Compile(rule.Data) + if err != nil { + common.Log.Errorf("failed to compile regex: %s", err) + return nil, "" + } + ctx, ok := task.Ctx.(*ext.Context) + if !ok { + common.Log.Fatalf("context is not *ext.Context: %T", task.Ctx) + return nil, "" + } + msg, err := bot.GetTGMessage(ctx, task.FileChatID, task.FileMessageID) + if err != nil { + common.Log.Errorf("failed to get message: %s", err) + return nil, "" + } + if msg == nil { + return nil, "" + } + if !ruleRegex.MatchString(msg.GetMessage()) { + return nil, "" + } + DirPath = rule.DirPath + StorageName = rule.StorageName + default: + common.Log.Errorf("unknown rule type: %s", rule.Type) + return nil, "" + } + taskStorageName := func() string { + if StorageName == "" || StorageName == "CHOSEN" { + return task.StorageName + } + return StorageName + }() + taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, taskStorageName) + if err != nil { + common.Log.Errorf("failed to get storage: %s", err) + return nil, "" + } + task.StoragePath = path.Join(DirPath, task.StoragePath) + return taskStorage, taskStorage.JoinStoragePath(task) +} diff --git a/dao/dir.go b/dao/dir.go index 17b8490..6b5fcb2 100644 --- a/dao/dir.go +++ b/dao/dir.go @@ -41,3 +41,7 @@ func GetDirsByUserIDAndStorageName(userID uint, storageName string) ([]Dir, erro func DeleteDirForUser(userID uint, storageName, path string) error { return db.Unscoped().Where("user_id = ? AND storage_name = ? AND path = ?", userID, storageName, path).Delete(&Dir{}).Error } + +func DeleteDirByID(id uint) error { + return db.Unscoped().Delete(&Dir{}, id).Error +} \ No newline at end of file diff --git a/dao/rule.go b/dao/rule.go new file mode 100644 index 0000000..452c5b2 --- /dev/null +++ b/dao/rule.go @@ -0,0 +1,22 @@ +package dao + +func CreateRule(rule *Rule) error { + return db.Create(rule).Error +} + +func DeleteRule(ruleID uint) error { + return db.Unscoped().Delete(&Rule{}, ruleID).Error +} + +func UpdateUserApplyRule(chatID int64, applyRule bool) error { + return db.Model(&User{}).Where("chat_id = ?", chatID).Update("apply_rule", applyRule).Error +} + +func GetRulesByUserChatID(chatID int64) ([]Rule, error) { + var rules []Rule + err := db.Where("user_id = (SELECT id FROM users WHERE chat_id = ?)", chatID).Find(&rules).Error + if err != nil { + return nil, err + } + return rules, nil +} diff --git a/types/types.go b/types/types.go index 25d4d65..4e11bbf 100644 --- a/types/types.go +++ b/types/types.go @@ -31,3 +31,12 @@ type ContextKey string const ( ContextKeyContentLength ContextKey = "content-length" ) + +type RuleType string + +const ( + RuleTypeFileNameRegex RuleType = "FILENAME-REGEX" + RuleTypeMessageRegex RuleType = "MESSAGE-REGEX" +) + +var RuleTypes = []RuleType{RuleTypeFileNameRegex, RuleTypeMessageRegex} \ No newline at end of file