mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-25 17:23:50 +08:00
feat: set default storage by inline keyboard
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
func InitAll() {
|
||||
@@ -18,7 +19,7 @@ func InitAll() {
|
||||
}
|
||||
logger.InitLogger()
|
||||
logger.L.Info("Starting SaveAny-Bot...")
|
||||
|
||||
storage.LoadStorages()
|
||||
common.Init()
|
||||
dao.Init()
|
||||
bot.Init()
|
||||
|
||||
@@ -35,8 +35,8 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
}
|
||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
|
||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
||||
// dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
|
||||
}
|
||||
|
||||
const noPermissionText string = `
|
||||
@@ -69,7 +69,6 @@ Save Any Bot - 转存你的 Telegram 文件
|
||||
/silent - 开关静默模式
|
||||
/storage - 设置默认存储位置
|
||||
/save [自定义文件名] - 保存文件
|
||||
/path <存储类型> <路径> - 更改文件保存路径
|
||||
|
||||
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
|
||||
|
||||
@@ -196,11 +195,82 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
FileMessageID: msg.ID,
|
||||
UserID: user.ChatID,
|
||||
})
|
||||
}
|
||||
|
||||
func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
// TODO: Implement
|
||||
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
ctx.Reply(update, ext.ReplyTextString("请选择要设为默认的存储位置"), &ext.ReplyOpts{
|
||||
Markup: getSetDefaultStorageMarkup(user.ChatID, storages),
|
||||
})
|
||||
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
|
||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||
userID, _ := strconv.Atoi(args[1])
|
||||
storageNameHash := args[2]
|
||||
if userID != int(update.CallbackQuery.GetUserID()) {
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "你没有权限",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storageName := storageHashName[storageNameHash]
|
||||
selectedStorage, err := storage.GetStorageByName(storageName)
|
||||
|
||||
if err != nil {
|
||||
logger.L.Errorf("failed to get storage: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "获取指定存储失败",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByChatID(int64(userID))
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "获取用户失败",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user.DefaultStorage = storageName
|
||||
if err := dao.UpdateUser(user); err != nil {
|
||||
logger.L.Errorf("Failed to update user: %s", err)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "更新用户失败",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("已将 %s (%s) 设为默认存储位置", selectedStorage.Name(), selectedStorage.Type()),
|
||||
ID: update.CallbackQuery.GetMsgID(),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
@@ -272,11 +342,12 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyMessageID: msg.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
FileMessageID: update.EffectiveMessage.ID,
|
||||
UserID: user.ChatID,
|
||||
})
|
||||
}
|
||||
|
||||
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
if !slice.Contain(config.Cfg.Telegram.Admins, update.CallbackQuery.UserID) {
|
||||
if !slice.Contain(config.Cfg.GetUsersID(), update.CallbackQuery.UserID) {
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
@@ -339,6 +410,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
FileMessageID: record.MessageID,
|
||||
ReplyChatID: record.ReplyChatID,
|
||||
UserID: update.EffectiveUser().GetID(),
|
||||
})
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
|
||||
20
bot/utils.go
20
bot/utils.go
@@ -68,6 +68,26 @@ func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*t
|
||||
return markup, nil
|
||||
}
|
||||
|
||||
func getSetDefaultStorageMarkup(userChatID int64, storages []storage.Storage) *tg.ReplyInlineMarkup {
|
||||
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||
for _, storage := range storages {
|
||||
nameHash := common.HashString(storage.Name())
|
||||
storageHashName[nameHash] = storage.Name()
|
||||
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||
Text: storage.Name(),
|
||||
Data: []byte(fmt.Sprintf("set_default %d %s", userChatID, nameHash)),
|
||||
})
|
||||
}
|
||||
markup := &tg.ReplyInlineMarkup{}
|
||||
for i := 0; i < len(buttons); i += 3 {
|
||||
row := tg.KeyboardButtonRow{}
|
||||
row.Buttons = buttons[i:min(i+3, len(buttons))]
|
||||
markup.Rows = append(markup.Rows, row)
|
||||
}
|
||||
return markup
|
||||
|
||||
}
|
||||
|
||||
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
|
||||
@@ -26,3 +26,24 @@ func (c *Config) GetStorageNamesByUserID(userID int64) []string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) GetUsersID() []int64 {
|
||||
var ids []int64
|
||||
for _, user := range c.Users {
|
||||
ids = append(ids, user.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (c *Config) HasStorage(userID int64, storageName string) bool {
|
||||
for _, user := range c.Users {
|
||||
if user.ID == userID {
|
||||
if user.Blacklist {
|
||||
return !slice.Contain(user.Storages, storageName)
|
||||
} else {
|
||||
return slice.Contain(user.Storages, storageName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -101,6 +101,16 @@ func Init() error {
|
||||
if Cfg.Telegram.Admins != nil {
|
||||
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
|
||||
for _, admin := range Cfg.Telegram.Admins {
|
||||
found := false
|
||||
for _, user := range Cfg.Users {
|
||||
if user.ID == admin {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
Cfg.Users = append(Cfg.Users, userConfig{
|
||||
ID: admin,
|
||||
Storages: []string{},
|
||||
|
||||
@@ -41,7 +41,7 @@ func processPendingTask(task *types.Task) error {
|
||||
task.StoragePath = task.File.FileName
|
||||
}
|
||||
|
||||
taskStorage, err := storage.GetStorageByName(task.StorageName)
|
||||
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -24,15 +24,6 @@ type Alist struct {
|
||||
config config.AlistStorageConfig
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{
|
||||
"url",
|
||||
"username",
|
||||
"password",
|
||||
"base_path",
|
||||
"token_exp",
|
||||
"token",
|
||||
}
|
||||
|
||||
func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
alistConfig, ok := cfg.(*config.AlistStorageConfig)
|
||||
if !ok {
|
||||
|
||||
@@ -15,10 +15,6 @@ type Local struct {
|
||||
config config.LocalStorageConfig
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{
|
||||
"base_path",
|
||||
}
|
||||
|
||||
func (l *Local) Init(cfg config.StorageConfig) error {
|
||||
localConfig, ok := cfg.(*config.LocalStorageConfig)
|
||||
if !ok {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage/alist"
|
||||
"github.com/krau/SaveAny-Bot/storage/local"
|
||||
"github.com/krau/SaveAny-Bot/storage/webdav"
|
||||
@@ -44,6 +45,19 @@ func GetStorageByName(name string) (Storage, error) {
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// 检查 user 是否可用指定的 storage, 若不可用则返回未找到错误
|
||||
func GetStorageByUserIDAndName(chatID int64, name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("storage name is required")
|
||||
}
|
||||
|
||||
if !config.Cfg.HasStorage(chatID, name) {
|
||||
return nil, fmt.Errorf("storage %s not found for user %d", name, chatID)
|
||||
}
|
||||
|
||||
return GetStorageByName(name)
|
||||
}
|
||||
|
||||
func GetUserStorages(chatID int64) []Storage {
|
||||
var storages []Storage
|
||||
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||
@@ -78,14 +92,13 @@ func NewStorage(cfg config.StorageConfig) (Storage, error) {
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
func GetStorageConfigurableItems(storageType types.StorageType) []string {
|
||||
switch storageType {
|
||||
case types.StorageTypeAlist:
|
||||
return alist.ConfigurableItems
|
||||
case types.StorageTypeLocal:
|
||||
return local.ConfigurableItems
|
||||
case types.StorageTypeWebdav:
|
||||
return webdav.ConfigurableItems
|
||||
func LoadStorages() {
|
||||
logger.L.Info("Loading storages")
|
||||
for _, storage := range config.Cfg.Storages {
|
||||
_, err := GetStorageByName(storage.GetName())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to load storage %s: %v", storage.GetName(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
logger.L.Infof("Successfully loaded %d storages", len(Storages))
|
||||
}
|
||||
|
||||
@@ -18,8 +18,6 @@ type Webdav struct {
|
||||
client *gowebdav.Client
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{"url", "username", "password", "base_path"}
|
||||
|
||||
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
|
||||
if !ok {
|
||||
|
||||
@@ -43,10 +43,13 @@ type Task struct {
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
// to track the reply message
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
// to track the user
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
|
||||
Reference in New Issue
Block a user