From c4eb824457e49952b4f6e98c57a1b0a68df0dcd6 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:23:12 +0800 Subject: [PATCH] feat: set default storage by inline keyboard --- bootstrap/init.go | 3 +- bot/handlers.go | 80 ++++++++++++++++++++++++++++++++++++++-- bot/utils.go | 20 ++++++++++ config/user.go | 21 +++++++++++ config/viper.go | 10 +++++ core/core.go | 2 +- storage/alist/alist.go | 9 ----- storage/local/local.go | 4 -- storage/storage.go | 31 +++++++++++----- storage/webdav/webdav.go | 2 - types/types.go | 7 +++- 11 files changed, 157 insertions(+), 32 deletions(-) diff --git a/bootstrap/init.go b/bootstrap/init.go index 4fa876d..94c323f 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -9,6 +9,7 @@ import ( "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/storage" ) func InitAll() { @@ -18,7 +19,7 @@ func InitAll() { } logger.InitLogger() logger.L.Info("Starting SaveAny-Bot...") - + storage.LoadStorages() common.Init() dao.Init() bot.Init() diff --git a/bot/handlers.go b/bot/handlers.go index 7637768..7ef06ef 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -35,8 +35,8 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { } dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue)) + dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage)) dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage)) - // dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation)) } const noPermissionText string = ` @@ -69,7 +69,6 @@ Save Any Bot - 转存你的 Telegram 文件 /silent - 开关静默模式 /storage - 设置默认存储位置 /save [自定义文件名] - 保存文件 -/path <存储类型> <路径> - 更改文件保存路径 静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问 @@ -196,11 +195,82 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { ReplyMessageID: replied.ID, ReplyChatID: update.GetUserChat().GetID(), FileMessageID: msg.ID, + UserID: user.ChatID, }) } func storageCmd(ctx *ext.Context, update *ext.Update) error { - // TODO: Implement + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) + if err != nil { + logger.L.Errorf("Failed to get user: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + storages := storage.GetUserStorages(user.ChatID) + if len(storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) + return dispatcher.EndGroups + } + + ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{ + Markup: getSetDefaultStorageMarkup(user.ChatID, storages), + }) + + return dispatcher.EndGroups +} + +func setDefaultStorage(ctx *ext.Context, update *ext.Update) error { + args := strings.Split(string(update.CallbackQuery.Data), " ") + userID, _ := strconv.Atoi(args[1]) + storageNameHash := args[2] + if userID != int(update.CallbackQuery.GetUserID()) { + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Alert: true, + Message: "你没有权限", + CacheTime: 5, + }) + return dispatcher.EndGroups + } + storageName := storageHashName[storageNameHash] + selectedStorage, err := storage.GetStorageByName(storageName) + + if err != nil { + logger.L.Errorf("failed to get storage: %s", err) + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Alert: true, + Message: "获取指定存储失败", + CacheTime: 5, + }) + return dispatcher.EndGroups + } + user, err := dao.GetUserByChatID(int64(userID)) + if err != nil { + logger.L.Errorf("Failed to get user: %s", err) + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Alert: true, + Message: "获取用户失败", + CacheTime: 5, + }) + return dispatcher.EndGroups + } + user.DefaultStorage = storageName + if err := dao.UpdateUser(user); err != nil { + logger.L.Errorf("Failed to update user: %s", err) + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Alert: true, + Message: "更新用户失败", + CacheTime: 5, + }) + return dispatcher.EndGroups + } + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()), + ID: update.CallbackQuery.GetMsgID(), + }) return dispatcher.EndGroups } @@ -272,11 +342,12 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error { ReplyMessageID: msg.ID, ReplyChatID: update.GetUserChat().GetID(), FileMessageID: update.EffectiveMessage.ID, + UserID: user.ChatID, }) } func AddToQueue(ctx *ext.Context, update *ext.Update) error { - if !slice.Contain(config.Cfg.Telegram.Admins, update.CallbackQuery.UserID) { + if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) { ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ QueryID: update.CallbackQuery.QueryID, Alert: true, @@ -339,6 +410,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { ReplyMessageID: record.ReplyMessageID, FileMessageID: record.MessageID, ReplyChatID: record.ReplyChatID, + UserID: update.EffectiveUser().GetID(), }) entityBuilder := entity.Builder{} diff --git a/bot/utils.go b/bot/utils.go index 1e556a4..4bc65e5 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -68,6 +68,26 @@ func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*t return markup, nil } +func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *tg.ReplyInlineMarkup { + buttons := make([]tg.KeyboardButtonClass, 0) + for _, storage := range storages { + nameHash := common.HashString(storage.Name()) + storageHashName[nameHash] = storage.Name() + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: storage.Name(), + Data: []byte(fmt.Sprintf("set_default %d %s", userChatID, nameHash)), + }) + } + markup := &tg.ReplyInlineMarkup{} + for i := 0; i < len(buttons); i += 3 { + row := tg.KeyboardButtonRow{} + row.Buttons = buttons[i:min(i+3, len(buttons))] + markup.Rows = append(markup.Rows, row) + } + return markup + +} + func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) { switch media := media.(type) { case *tg.MessageMediaDocument: diff --git a/config/user.go b/config/user.go index b4cd5f5..6754834 100644 --- a/config/user.go +++ b/config/user.go @@ -26,3 +26,24 @@ func (c *Config) GetStorageNamesByUserID(userID int64) []string { } return nil } + +func (c *Config) GetUsersID() []int64 { + var ids []int64 + for _, user := range c.Users { + ids = append(ids, user.ID) + } + return ids +} + +func (c *Config) HasStorage(userID int64, storageName string) bool { + for _, user := range c.Users { + if user.ID == userID { + if user.Blacklist { + return !slice.Contain(user.Storages, storageName) + } else { + return slice.Contain(user.Storages, storageName) + } + } + } + return false +} diff --git a/config/viper.go b/config/viper.go index 26d90b7..88a602b 100644 --- a/config/viper.go +++ b/config/viper.go @@ -101,6 +101,16 @@ func Init() error { if Cfg.Telegram.Admins != nil { fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.") for _, admin := range Cfg.Telegram.Admins { + found := false + for _, user := range Cfg.Users { + if user.ID == admin { + found = true + break + } + } + if found { + continue + } Cfg.Users = append(Cfg.Users, userConfig{ ID: admin, Storages: []string{}, diff --git a/core/core.go b/core/core.go index 21ae3b4..51e60b2 100644 --- a/core/core.go +++ b/core/core.go @@ -41,7 +41,7 @@ func processPendingTask(task *types.Task) error { task.StoragePath = task.File.FileName } - taskStorage, err := storage.GetStorageByName(task.StorageName) + taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) if err != nil { return err } diff --git a/storage/alist/alist.go b/storage/alist/alist.go index f349ff0..f4127d1 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -24,15 +24,6 @@ type Alist struct { config config.AlistStorageConfig } -var ConfigurableItems = []string{ - "url", - "username", - "password", - "base_path", - "token_exp", - "token", -} - func (a *Alist) Init(cfg config.StorageConfig) error { alistConfig, ok := cfg.(*config.AlistStorageConfig) if !ok { diff --git a/storage/local/local.go b/storage/local/local.go index 939f871..f6b348b 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -15,10 +15,6 @@ type Local struct { config config.LocalStorageConfig } -var ConfigurableItems = []string{ - "base_path", -} - func (l *Local) Init(cfg config.StorageConfig) error { localConfig, ok := cfg.(*config.LocalStorageConfig) if !ok { diff --git a/storage/storage.go b/storage/storage.go index 56f603c..96952ef 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/storage/alist" "github.com/krau/SaveAny-Bot/storage/local" "github.com/krau/SaveAny-Bot/storage/webdav" @@ -44,6 +45,19 @@ func GetStorageByName(name string) (Storage, error) { return storage, nil } +// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误 +func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) { + if name == "" { + return nil, fmt.Errorf("storage name is required") + } + + if !config.Cfg.HasStorage(chatID, name) { + return nil, fmt.Errorf("storage %s not found for user %d", name, chatID) + } + + return GetStorageByName(name) +} + func GetUserStorages(chatID int64) []Storage { var storages []Storage for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) { @@ -78,14 +92,13 @@ func NewStorage(cfg config.StorageConfig) (Storage, error) { return storage, nil } -func GetStorageConfigurableItems(storageType types.StorageType) []string { - switch storageType { - case types.StorageTypeAlist: - return alist.ConfigurableItems - case types.StorageTypeLocal: - return local.ConfigurableItems - case types.StorageTypeWebdav: - return webdav.ConfigurableItems +func LoadStorages() { + logger.L.Info("Loading storages") + for _, storage := range config.Cfg.Storages { + _, err := GetStorageByName(storage.GetName()) + if err != nil { + logger.L.Errorf("Failed to load storage %s: %v", storage.GetName(), err) + } } - return nil + logger.L.Infof("Successfully loaded %d storages", len(Storages)) } diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index c1316e9..f7d1b10 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -18,8 +18,6 @@ type Webdav struct { client *gowebdav.Client } -var ConfigurableItems = []string{"url", "username", "password", "base_path"} - func (w *Webdav) Init(cfg config.StorageConfig) error { webdavConfig, ok := cfg.(*config.WebdavStorageConfig) if !ok { diff --git a/types/types.go b/types/types.go index 5b07292..9d448f2 100644 --- a/types/types.go +++ b/types/types.go @@ -43,10 +43,13 @@ type Task struct { StoragePath string StartTime time.Time - FileMessageID int - FileChatID int64 + FileMessageID int + FileChatID int64 + // to track the reply message ReplyMessageID int ReplyChatID int64 + // to track the user + UserID int64 } func (t Task) String() string {