feat: set default storage by inline keyboard

This commit is contained in:
krau
2025-02-19 12:23:12 +08:00
parent 692e970772
commit c4eb824457
11 changed files with 157 additions and 32 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
)
func InitAll() {
@@ -18,7 +19,7 @@ func InitAll() {
}
logger.InitLogger()
logger.L.Info("Starting SaveAny-Bot...")
storage.LoadStorages()
common.Init()
dao.Init()
bot.Init()

View File

@@ -35,8 +35,8 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
}
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
// dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
}
const noPermissionText string = `
@@ -69,7 +69,6 @@ Save Any Bot - 转存你的 Telegram 文件
/silent - 开关静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/path <存储类型> <路径> - 更改文件保存路径
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
@@ -196,11 +195,82 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: msg.ID,
UserID: user.ChatID,
})
}
func storageCmd(ctx *ext.Context, update *ext.Update) error {
// TODO: Implement
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
Markup: getSetDefaultStorageMarkup(user.ChatID, storages),
})
return dispatcher.EndGroups
}
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
userID, _ := strconv.Atoi(args[1])
storageNameHash := args[2]
if userID != int(update.CallbackQuery.GetUserID()) {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "你没有权限",
CacheTime: 5,
})
return dispatcher.EndGroups
}
storageName := storageHashName[storageNameHash]
selectedStorage, err := storage.GetStorageByName(storageName)
if err != nil {
logger.L.Errorf("failed to get storage: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取指定存储失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(int64(userID))
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取用户失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
user.DefaultStorage = storageName
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "更新用户失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
ID: update.CallbackQuery.GetMsgID(),
})
return dispatcher.EndGroups
}
@@ -272,11 +342,12 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: update.EffectiveMessage.ID,
UserID: user.ChatID,
})
}
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
if !slice.Contain(config.Cfg.Telegram.Admins, update.CallbackQuery.UserID) {
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
@@ -339,6 +410,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.EffectiveUser().GetID(),
})
entityBuilder := entity.Builder{}

View File

@@ -68,6 +68,26 @@ func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*t
return markup, nil
}
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *tg.ReplyInlineMarkup {
buttons := make([]tg.KeyboardButtonClass, 0)
for _, storage := range storages {
nameHash := common.HashString(storage.Name())
storageHashName[nameHash] = storage.Name()
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: storage.Name(),
Data: []byte(fmt.Sprintf("set_default %d %s", userChatID, nameHash)),
})
}
markup := &tg.ReplyInlineMarkup{}
for i := 0; i < len(buttons); i += 3 {
row := tg.KeyboardButtonRow{}
row.Buttons = buttons[i:min(i+3, len(buttons))]
markup.Rows = append(markup.Rows, row)
}
return markup
}
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
switch media := media.(type) {
case *tg.MessageMediaDocument:

View File

@@ -26,3 +26,24 @@ func (c *Config) GetStorageNamesByUserID(userID int64) []string {
}
return nil
}
func (c *Config) GetUsersID() []int64 {
var ids []int64
for _, user := range c.Users {
ids = append(ids, user.ID)
}
return ids
}
func (c *Config) HasStorage(userID int64, storageName string) bool {
for _, user := range c.Users {
if user.ID == userID {
if user.Blacklist {
return !slice.Contain(user.Storages, storageName)
} else {
return slice.Contain(user.Storages, storageName)
}
}
}
return false
}

View File

@@ -101,6 +101,16 @@ func Init() error {
if Cfg.Telegram.Admins != nil {
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
for _, admin := range Cfg.Telegram.Admins {
found := false
for _, user := range Cfg.Users {
if user.ID == admin {
found = true
break
}
}
if found {
continue
}
Cfg.Users = append(Cfg.Users, userConfig{
ID: admin,
Storages: []string{},

View File

@@ -41,7 +41,7 @@ func processPendingTask(task *types.Task) error {
task.StoragePath = task.File.FileName
}
taskStorage, err := storage.GetStorageByName(task.StorageName)
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
if err != nil {
return err
}

View File

@@ -24,15 +24,6 @@ type Alist struct {
config config.AlistStorageConfig
}
var ConfigurableItems = []string{
"url",
"username",
"password",
"base_path",
"token_exp",
"token",
}
func (a *Alist) Init(cfg config.StorageConfig) error {
alistConfig, ok := cfg.(*config.AlistStorageConfig)
if !ok {

View File

@@ -15,10 +15,6 @@ type Local struct {
config config.LocalStorageConfig
}
var ConfigurableItems = []string{
"base_path",
}
func (l *Local) Init(cfg config.StorageConfig) error {
localConfig, ok := cfg.(*config.LocalStorageConfig)
if !ok {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/webdav"
@@ -44,6 +45,19 @@ func GetStorageByName(name string) (Storage, error) {
return storage, nil
}
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
if name == "" {
return nil, fmt.Errorf("storage name is required")
}
if !config.Cfg.HasStorage(chatID, name) {
return nil, fmt.Errorf("storage %s not found for user %d", name, chatID)
}
return GetStorageByName(name)
}
func GetUserStorages(chatID int64) []Storage {
var storages []Storage
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
@@ -78,14 +92,13 @@ func NewStorage(cfg config.StorageConfig) (Storage, error) {
return storage, nil
}
func GetStorageConfigurableItems(storageType types.StorageType) []string {
switch storageType {
case types.StorageTypeAlist:
return alist.ConfigurableItems
case types.StorageTypeLocal:
return local.ConfigurableItems
case types.StorageTypeWebdav:
return webdav.ConfigurableItems
func LoadStorages() {
logger.L.Info("Loading storages")
for _, storage := range config.Cfg.Storages {
_, err := GetStorageByName(storage.GetName())
if err != nil {
logger.L.Errorf("Failed to load storage %s: %v", storage.GetName(), err)
}
}
return nil
logger.L.Infof("Successfully loaded %d storages", len(Storages))
}

View File

@@ -18,8 +18,6 @@ type Webdav struct {
client *gowebdav.Client
}
var ConfigurableItems = []string{"url", "username", "password", "base_path"}
func (w *Webdav) Init(cfg config.StorageConfig) error {
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
if !ok {

View File

@@ -43,10 +43,13 @@ type Task struct {
StoragePath string
StartTime time.Time
FileMessageID int
FileChatID int64
FileMessageID int
FileChatID int64
// to track the reply message
ReplyMessageID int
ReplyChatID int64
// to track the user
UserID int64
}
func (t Task) String() string {