feat!: (WIP) switched back to using config files config storages because the conversation handling is shit

This commit is contained in:
krau
2025-02-19 11:05:30 +08:00
parent 80696c9661
commit 692e970772
24 changed files with 584 additions and 645 deletions

View File

@@ -1,6 +1,9 @@
package bootstrap
import (
"fmt"
"os"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
@@ -9,7 +12,10 @@ import (
)
func InitAll() {
config.Init()
if err := config.Init(); err != nil {
fmt.Println("Failed to init config: ", err)
os.Exit(1)
}
logger.InitLogger()
logger.L.Info("Starting SaveAny-Bot...")

View File

@@ -1,232 +0,0 @@
package bot
import (
"encoding/json"
"fmt"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func manageStorageEntry(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
if err != nil {
logger.L.Errorf("Failed to get user active storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil)
return dispatcher.EndGroups
}
state, ok := userConversationState[user.ChatID]
if !ok {
state = &ConversationState{}
userConversationState[user.ChatID] = state
}
state.Reset()
state.InConversation = true
state.SetConversationType(ConversationTypeManageStorage)
state.SetData("status", "entry")
storagesMsg := "已添加的存储:"
if len(user.Storages) == 0 {
storagesMsg += " 无"
} else {
for i, storage := range user.Storages {
storagesMsg += fmt.Sprintf("\n%d. %s", i+1, storage.Name)
}
}
storagesMsg += "\n\n请选择操作:"
_, err = ctx.Reply(update, ext.ReplyTextString(storagesMsg), &ext.ReplyOpts{
Markup: &manageStorageKeyboardMarkup,
})
if err != nil {
logger.L.Errorf("Failed to send manage storage message: %s", err)
return dispatcher.EndGroups
}
return dispatcher.EndGroups
}
func handleManageStorageConversation(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
status := state.GetData("status").(string)
switch status {
case "entry":
return manageStorageMenu(ctx, update, state)
case "add_select_type":
return manageStorageAddSelectType(ctx, update, state)
case "selected_add_type":
return manageStorageAddSelectedType(ctx, update, state)
default:
logger.L.Errorf("Unknown manage storage status: %s", status)
}
return dispatcher.EndGroups
}
func manageStorageMenu(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
text := update.EffectiveMessage.Text
switch text {
case manageStorageButtonAdd:
return manageStorageAdd(ctx, update, state)
case manageStorageButtonDelete:
return manageStorageDelete(ctx, update)
case manageStorageButtonEdit:
return manageStorageEdit(ctx, update)
case manageStorageButtonSetDefault:
return manageStorageSetDefault(ctx, update)
default:
logger.L.Errorf("Unknown manage storage button: %s", text)
ctx.Reply(update, ext.ReplyTextString("未知操作"), nil)
return dispatcher.EndGroups
}
}
func manageStorageAdd(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
rows := make([]tg.KeyboardButtonRow, 0)
buttons := make([]tg.KeyboardButtonClass, 0)
for i, storageType := range types.StorageTypes {
buttons = append(buttons, &tg.KeyboardButton{
Text: types.StorageTypeDisplay[storageType],
})
if (i+1)%3 == 0 || i == len(types.StorageTypes)-1 {
rows = append(rows, tg.KeyboardButtonRow{
Buttons: buttons,
})
buttons = make([]tg.KeyboardButtonClass, 0)
}
}
manageStorageAddKeyboardMarkup := tg.ReplyKeyboardMarkup{
Selective: true,
Resize: true,
Rows: rows,
}
state.SetData("status", "add_select_type")
ctx.Reply(update, ext.ReplyTextString("请选择要添加的存储类型"), &ext.ReplyOpts{
Markup: &manageStorageAddKeyboardMarkup,
})
return dispatcher.ContinueGroups
}
func manageStorageAddSelectType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
text := update.EffectiveMessage.Text
var storageType types.StorageType
for t, display := range types.StorageTypeDisplay {
if display == text {
storageType = t
break
}
}
if storageType == "" {
ctx.Reply(update, ext.ReplyTextString("未知的存储类型"), nil)
return dispatcher.EndGroups
}
state.SetData("status", "selected_add_type")
state.SetData("storage_type", storageType)
return manageStorageAddSelectedType(ctx, update, state)
}
func manageStorageAddSelectedType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
selectedType := state.GetData("storage_type").(types.StorageType)
configItems := storage.GetStorageConfigurableItems(selectedType)
configIndexData := state.GetData("configindex")
configIndex := 0
if configIndexData == nil {
state.SetData("configindex", configIndex)
} else {
configIndex = configIndexData.(int)
if update.EffectiveMessage.Text != "" {
logger.L.Debugf("config %s: %s", configItems[configIndex-1], update.EffectiveMessage.Text)
state.SetData(configItems[configIndex-1], update.EffectiveMessage.Text)
}
}
if configIndex >= len(configItems) {
// TODO: save storage
state.SetData("status", "add_complete")
logger.L.Infof("Save storage")
return manageStorageSave(ctx, update, state)
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("正在配置 %s 存储...\n请提供 %s", types.StorageTypeDisplay[selectedType], configItems[configIndex])), &ext.ReplyOpts{
Markup: &tg.ReplyKeyboardForceReply{
Selective: true,
SingleUse: true,
},
})
state.SetData("configindex", configIndex+1)
return dispatcher.EndGroups
}
func manageStorageSave(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
defer func() {
if r := recover(); r != nil {
logger.L.Errorf("Failed to save storage: %s", r)
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
}
state.Reset()
}()
storageType := state.GetData("storage_type").(types.StorageType)
config := make(map[string]string)
configItems := storage.GetStorageConfigurableItems(storageType)
for _, item := range configItems {
config[item] = state.GetData(item).(string)
}
configJSON, err := json.Marshal(config)
if err != nil {
logger.L.Errorf("Failed to marshal storage config: %s", err)
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storageModel := types.StorageModel{
Type: string(storageType),
Active: true,
Config: configJSON,
}
hash := storageModel.GenHash()
storageModel.Hash = hash
if storagedb, err := dao.GetStorageByHash(hash); err == nil {
logger.L.Debugf("Storage already exists")
user.Storages = append(user.Storages, storagedb)
} else {
if id, err := dao.CreateStorage(&storageModel); err != nil {
logger.L.Errorf("Failed to create storage: %s", err)
ctx.Reply(update, ext.ReplyTextString("存储创建失败"), nil)
return dispatcher.EndGroups
} else {
storagedb := &types.StorageModel{}
storagedb.ID = id
user.Storages = append(user.Storages, storagedb)
}
}
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user with storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("用户更新失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("存储已添加"), &ext.ReplyOpts{
Markup: &tg.ReplyKeyboardHide{},
})
return dispatcher.EndGroups
}
func manageStorageDelete(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}
func manageStorageEdit(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}
func manageStorageSetDefault(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -53,7 +54,9 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
@@ -91,14 +94,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
})
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageID: user.DefaultStorageID,
StorageName: user.DefaultStorage,
FileChatID: linkChat.GetID(),
FileMessageID: messageID,
ReplyMessageID: replied.ID,

View File

@@ -2,18 +2,10 @@ package bot
import (
"sync"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/logger"
)
type ConversationType string
const (
ConversationTypeManageStorage ConversationType = "manage_storage"
)
type ConversationState struct {
sync.Mutex
conversationType ConversationType
@@ -54,31 +46,30 @@ func (c *ConversationState) SetData(key string, value interface{}) {
c.data[c.conversationType][key] = value
}
var userConversationState = make(map[int64]*ConversationState)
// TODO: Implement conversation handling
// var userConversationState = make(map[int64]*ConversationState)
func handleConversation(ctx *ext.Context, update *ext.Update) error {
userID := update.EffectiveUser().GetID()
state, ok := userConversationState[userID]
if !ok {
return dispatcher.ContinueGroups
}
if update.EffectiveMessage.Text == "/cancel" {
state.Reset()
ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
return dispatcher.EndGroups
}
if !state.InConversation {
return dispatcher.ContinueGroups
}
return handleConversationState(ctx, update, state)
}
// func handleConversation(ctx *ext.Context, update *ext.Update) error {
// userID := update.EffectiveUser().GetID()
// state, ok := userConversationState[userID]
// if !ok {
// return dispatcher.ContinueGroups
// }
// if update.EffectiveMessage.Text == "/cancel" {
// state.Reset()
// ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
// return dispatcher.EndGroups
// }
// if !state.InConversation {
// return dispatcher.ContinueGroups
// }
// return handleConversationState(ctx, update, state)
// }
func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
switch state.conversationType {
case ConversationTypeManageStorage:
return handleManageStorageConversation(ctx, update, state)
default:
logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
}
return dispatcher.EndGroups
}
// func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
// switch state.conversationType {
// default:
// logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
// }
// return dispatcher.EndGroups
// }

View File

@@ -18,6 +18,7 @@ import (
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -26,7 +27,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewCommand("start", start))
dispatcher.AddHandler(handlers.NewCommand("help", help))
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", manageStorageEntry))
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
if err != nil {
@@ -35,7 +36,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
// dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
}
const noPermissionText string = `
@@ -120,7 +121,10 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
@@ -180,14 +184,14 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageID: user.DefaultStorageID,
StorageName: user.DefaultStorage,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
@@ -195,6 +199,11 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
})
}
func storageCmd(ctx *ext.Context, update *ext.Update) error {
// TODO: Implement
return dispatcher.EndGroups
}
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
@@ -211,7 +220,8 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
@@ -250,14 +260,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageID: user.DefaultStorageID,
StorageName: user.DefaultStorage,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
@@ -278,8 +288,19 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
fileChatID, _ := strconv.Atoi(args[1])
fileMessageID, _ := strconv.Atoi(args[2])
storageID, _ := strconv.Atoi(args[3])
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storageID: %d", fileChatID, fileMessageID, storageID)
storageNameHash := args[3]
storageName := storageHashName[storageNameHash]
if storageName == "" {
logger.L.Errorf("Unknown storage name hash: %d", storageNameHash)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "未知存储位置",
CacheTime: 5,
})
return dispatcher.EndGroups
}
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)
@@ -313,7 +334,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
Ctx: ctx,
Status: types.Pending,
File: file,
StorageID: uint(storageID),
StorageName: storageName,
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,

View File

@@ -14,6 +14,7 @@ import (
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -25,39 +26,6 @@ var (
ErrNoStorages = errors.New("no available storage")
)
var (
manageStorageButtonAdd = "添加存储"
manageStorageButtonDelete = "删除存储"
manageStorageButtonEdit = "修改存储"
manageStorageButtonSetDefault = "设置默认存储"
manageStorageKeyboardMarkup = tg.ReplyKeyboardMarkup{
Selective: true,
Resize: true,
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButton{
Text: manageStorageButtonAdd,
},
&tg.KeyboardButton{
Text: manageStorageButtonDelete,
},
&tg.KeyboardButton{
Text: manageStorageButtonEdit,
},
},
},
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButton{
Text: manageStorageButtonSetDefault,
},
},
},
},
}
)
func supportedMediaFilter(m *tg.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups
@@ -72,19 +40,23 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
}
}
// for callback data
var storageHashName = map[string]string{}
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
user, err := dao.GetUserByChatID(userChatID)
if err != nil {
return nil, err
}
if len(user.Storages) < 1 {
return nil, ErrNoStorages
}
storages := storage.GetUserStorages(user.ChatID)
buttons := make([]tg.KeyboardButtonClass, 0)
for _, storage := range user.Storages {
for _, storage := range storages {
nameHash := common.HashString(storage.Name())
storageHashName[nameHash] = storage.Name()
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: storage.Name,
Data: []byte(fmt.Sprintf("add %d %d %d", fileChatID, fileMessageID, storage.ID)),
Text: storage.Name(),
Data: []byte(fmt.Sprintf("add %d %d %s", fileChatID, fileMessageID, nameHash)),
})
}
markup := &tg.ReplyInlineMarkup{}
@@ -194,7 +166,7 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
return tgMessage, nil
}
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int, fileMsgID, toEditMsgID int) error {
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
@@ -207,7 +179,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
} else {
text, entities = entityBuilder.Complete()
}
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), chatID, fileMsgID)
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), int(chatID), fileMsgID)
if errors.Is(err, ErrNoStorages) {
logger.L.Errorf("Failed to get select storage markup: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
@@ -236,7 +208,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
}
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
if user.DefaultStorageID == 0 {
if user.DefaultStorage == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
ID: task.ReplyMessageID,

12
common/utils.go Normal file
View File

@@ -0,0 +1,12 @@
package common
import (
"crypto/md5"
"encoding/hex"
)
func HashString(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}

95
config/deprecated.go Normal file
View File

@@ -0,0 +1,95 @@
package config
import (
"strconv"
"github.com/krau/SaveAny-Bot/types"
"gorm.io/datatypes"
)
// for compatibility
type deprecatedStorageConfig struct {
Alist alistConfig `toml:"alist" mapstructure:"alist"`
Local localConfig `toml:"local" mapstructure:"local"`
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
}
type alistConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
Token string `toml:"token" mapstructure:"token" json:"token"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
}
func (a *alistConfig) ToJSON() datatypes.JSON {
tokenExp := strconv.FormatInt(a.TokenExp, 10)
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
}
type localConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (l *localConfig) ToJSON() datatypes.JSON {
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
}
type webdavConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (w *webdavConfig) ToJSON() datatypes.JSON {
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
}
func transformDeprecatedStorageConfig() {
if Cfg.DeprecatedStorage.Alist.Enable {
alistStorage := &AlistStorageConfig{
NewStorageConfig: NewStorageConfig{
Name: "Alist",
Enable: true,
Type: string(types.StorageTypeAlist),
},
URL: Cfg.DeprecatedStorage.Alist.URL,
Username: Cfg.DeprecatedStorage.Alist.Username,
Password: Cfg.DeprecatedStorage.Alist.Password,
Token: Cfg.DeprecatedStorage.Alist.Token,
BasePath: Cfg.DeprecatedStorage.Alist.BasePath,
TokenExp: Cfg.DeprecatedStorage.Alist.TokenExp,
}
Cfg.Storages = append(Cfg.Storages, alistStorage)
}
if Cfg.DeprecatedStorage.Local.Enable {
localStorage := &LocalStorageConfig{
NewStorageConfig: NewStorageConfig{
Name: "Local",
Enable: true,
Type: string(types.StorageTypeLocal),
},
BasePath: Cfg.DeprecatedStorage.Local.BasePath,
}
Cfg.Storages = append(Cfg.Storages, localStorage)
}
if Cfg.DeprecatedStorage.Webdav.Enable {
webdavStorage := &WebdavStorageConfig{
NewStorageConfig: NewStorageConfig{
Name: "Webdav",
Enable: true,
Type: string(types.StorageTypeWebdav),
},
URL: Cfg.DeprecatedStorage.Webdav.URL,
Username: Cfg.DeprecatedStorage.Webdav.Username,
Password: Cfg.DeprecatedStorage.Webdav.Password,
BasePath: Cfg.DeprecatedStorage.Webdav.BasePath,
}
Cfg.Storages = append(Cfg.Storages, webdavStorage)
}
}

104
config/storage_factory.go Normal file
View File

@@ -0,0 +1,104 @@
// storage_config.go
package config
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
)
type StorageConfig interface {
Validate() error
GetType() types.StorageType
GetName() string
}
// Base storage config
type NewStorageConfig struct {
Name string `toml:"name" mapstructure:"name" json:"name"`
Type string `toml:"type" mapstructure:"type" json:"type"`
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
RawConfig map[string]interface{} `toml:"-" mapstructure:",remain"`
}
type StorageConfigFactory func(cfg *NewStorageConfig) (StorageConfig, error)
var storageFactories = make(map[string]StorageConfigFactory)
func RegisterStorageFactory(storageType string, factory StorageConfigFactory) {
storageFactories[storageType] = factory
}
func init() {
RegisterStorageFactory(string(types.StorageTypeLocal), newLocalStorageConfig)
RegisterStorageFactory(string(types.StorageTypeAlist), newAlistStorageConfig)
RegisterStorageFactory(string(types.StorageTypeWebdav), newWebdavStorageConfig)
}
func newLocalStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
var localCfg LocalStorageConfig
localCfg.NewStorageConfig = *cfg
if err := mapstructure.Decode(cfg.RawConfig, &localCfg); err != nil {
return nil, fmt.Errorf("failed to decode local storage config: %w", err)
}
return &localCfg, nil
}
func newAlistStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
var alistCfg AlistStorageConfig
alistCfg.NewStorageConfig = *cfg
if err := mapstructure.Decode(cfg.RawConfig, &alistCfg); err != nil {
return nil, fmt.Errorf("failed to decode alist storage config: %w", err)
}
return &alistCfg, nil
}
func newWebdavStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
var webdavCfg WebdavStorageConfig
webdavCfg.NewStorageConfig = *cfg
if err := mapstructure.Decode(cfg.RawConfig, &webdavCfg); err != nil {
return nil, fmt.Errorf("failed to decode webdav storage config: %w", err)
}
return &webdavCfg, nil
}
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
var baseConfigs []NewStorageConfig
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
}
var configs []StorageConfig
for _, baseCfg := range baseConfigs {
if !baseCfg.Enable {
continue
}
factory, ok := storageFactories[baseCfg.Type]
if !ok {
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
}
cfg, err := factory(&baseCfg)
if err != nil {
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
}
configs = append(configs, cfg)
}
return configs, nil
}

106
config/storages.go Normal file
View File

@@ -0,0 +1,106 @@
package config
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
func (c *Config) GetStoragesByType(storageType types.StorageType) []StorageConfig {
var storages []StorageConfig
for _, storage := range c.Storages {
if storage.GetType() == storageType {
storages = append(storages, storage)
}
}
return storages
}
func (c *Config) GetStorageByName(name string) StorageConfig {
for _, storage := range c.Storages {
if storage.GetName() == name {
return storage
}
}
return nil
}
type LocalStorageConfig struct {
NewStorageConfig
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (l *LocalStorageConfig) Validate() error {
if l.BasePath == "" {
return fmt.Errorf("path is required for local storage")
}
return nil
}
func (l *LocalStorageConfig) GetType() types.StorageType {
return types.StorageTypeLocal
}
func (l *LocalStorageConfig) GetName() string {
return l.Name
}
type AlistStorageConfig struct {
NewStorageConfig
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
Token string `toml:"token" mapstructure:"token" json:"token"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
}
func (a *AlistStorageConfig) Validate() error {
if a.URL == "" {
return fmt.Errorf("url is required for alist storage")
}
if a.Token == "" && (a.Username == "" || a.Password == "") {
return fmt.Errorf("username and password or token is required for alist storage")
}
if a.BasePath == "" {
return fmt.Errorf("base_path is required for alist storage")
}
return nil
}
func (a *AlistStorageConfig) GetType() types.StorageType {
return types.StorageTypeAlist
}
func (a *AlistStorageConfig) GetName() string {
return a.Name
}
type WebdavStorageConfig struct {
NewStorageConfig
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (w *WebdavStorageConfig) Validate() error {
if w.URL == "" {
return fmt.Errorf("url is required for webdav storage")
}
if w.Username == "" || w.Password == "" {
return fmt.Errorf("username and password is required for webdav storage")
}
if w.BasePath == "" {
return fmt.Errorf("base_path is required for webdav storage")
}
return nil
}
func (w *WebdavStorageConfig) GetType() types.StorageType {
return types.StorageTypeWebdav
}
func (w *WebdavStorageConfig) GetName() string {
return w.Name
}

28
config/user.go Normal file
View File

@@ -0,0 +1,28 @@
package config
import (
"github.com/duke-git/lancet/v2/slice"
)
type userConfig struct {
ID int64 `toml:"id" mapstructure:"id" json:"id"` // telegram user id
Storages []string `toml:"storages" mapstructure:"storages" json:"storages"` // storage names
Blacklist bool `toml:"blacklist" mapstructure:"blacklist" json:"blacklist"` // 黑名单模式, storage names 中的存储将不会被使用, 默认为白名单模式
}
func (c *Config) GetStorageNamesByUserID(userID int64) []string {
for _, user := range c.Users {
if user.ID == userID {
if user.Blacklist {
allStorages := make([]string, 0, len(c.Storages))
for _, storage := range c.Storages {
allStorages = append(allStorages, storage.GetName())
}
return slice.Compact(slice.Difference(allStorages, user.Storages))
} else {
return user.Storages
}
}
}
return nil
}

View File

@@ -3,23 +3,24 @@ package config
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/spf13/viper"
"gorm.io/datatypes"
)
type Config struct {
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
Temp tempConfig `toml:"temp" mapstructure:"temp"`
Log logConfig `toml:"log" mapstructure:"log"`
DB dbConfig `toml:"db" mapstructure:"db"`
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
Storage storageConfig `toml:"storage" mapstructure:"storage"`
Temp tempConfig `toml:"temp" mapstructure:"temp"`
Log logConfig `toml:"log" mapstructure:"log"`
DB dbConfig `toml:"db" mapstructure:"db"`
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
Storages []StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
// Deprecated
DeprecatedStorage deprecatedStorageConfig `toml:"storage" mapstructure:"storage"`
}
type tempConfig struct {
@@ -38,12 +39,13 @@ type dbConfig struct {
}
type telegramConfig struct {
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
// 白名单用户
Admins []int64 `toml:"admins" mapstructure:"admins"` // Whitelisted users
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
// Deprecated
Admins []int64 `toml:"admins" mapstructure:"admins"`
}
type proxyConfig struct {
@@ -51,56 +53,9 @@ type proxyConfig struct {
URL string `toml:"url" mapstructure:"url"`
}
// pre-defined storages, for compatibility.
/*
在配置文件中定义的存储将会为telegram.admins中的每个用户创建一个存储模型
*/
// these config will be removed in the future.
type storageConfig struct {
Alist AlistConfig `toml:"alist" mapstructure:"alist"`
Local LocalConfig `toml:"local" mapstructure:"local"`
Webdav WebdavConfig `toml:"webdav" mapstructure:"webdav"`
}
type AlistConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
Token string `toml:"token" mapstructure:"token" json:"token"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
}
func (a *AlistConfig) ToJSON() datatypes.JSON {
tokenExp := strconv.FormatInt(a.TokenExp, 10)
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
}
type LocalConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (l *LocalConfig) ToJSON() datatypes.JSON {
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
}
type WebdavConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
URL string `toml:"url" mapstructure:"url" json:"url"`
Username string `toml:"username" mapstructure:"username" json:"username"`
Password string `toml:"password" mapstructure:"password" json:"password"`
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}
func (w *WebdavConfig) ToJSON() datatypes.JSON {
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
}
var Cfg *Config
func Init() {
func Init() error {
viper.SetConfigName("config")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/saveany/")
@@ -125,7 +80,11 @@ func Init() {
viper.SetDefault("db.path", "data/saveany.db")
viper.SafeWriteConfigAs("config.toml")
if err := viper.SafeWriteConfigAs("config.toml"); err != nil {
if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok {
return fmt.Errorf("error saving default config: %w", err)
}
}
if err := viper.ReadInConfig(); err != nil {
fmt.Println("Error reading config file, ", err)
@@ -133,17 +92,52 @@ func Init() {
}
Cfg = &Config{}
if err := viper.Unmarshal(Cfg); err != nil {
fmt.Println("Error unmarshalling config file, ", err)
os.Exit(1)
}
if Cfg.Storage != (storageConfig{}) {
fmt.Println("警告: 存储配置已经废弃, 未来版本将会移除.\n请直接使用 Bot 命令添加存储.")
if Cfg.Telegram.Admins != nil {
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
for _, admin := range Cfg.Telegram.Admins {
Cfg.Users = append(Cfg.Users, userConfig{
ID: admin,
Storages: []string{},
Blacklist: true,
})
}
}
storagesConfig, err := LoadStorageConfigs(viper.GetViper())
if err != nil {
return fmt.Errorf("error loading storage configs: %w", err)
}
Cfg.Storages = storagesConfig
if Cfg.DeprecatedStorage != (deprecatedStorageConfig{}) {
fmt.Println("\n警告: 你正在使用旧版存储配置, 未来版本将会被废弃.\n请参考新的配置文件模板.")
transformDeprecatedStorageConfig()
}
storageNames := make(map[string]struct{})
for _, storage := range Cfg.Storages {
if _, ok := storageNames[storage.GetName()]; ok {
return fmt.Errorf("重复的存储名: %s", storage.GetName())
}
storageNames[storage.GetName()] = struct{}{}
}
fmt.Printf("已加载 %d 个存储:\n", len(Cfg.Storages))
for _, storage := range Cfg.Storages {
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
}
if Cfg.Workers < 1 || Cfg.Retry < 1 {
fmt.Println("Invalid workers or retry value")
os.Exit(1)
return fmt.Errorf("workers 和 retry 必须大于 0, 当前值: workers=%d, retry=%d", Cfg.Workers, Cfg.Retry)
}
return nil
}
func Set(key string, value any) {

View File

@@ -17,7 +17,6 @@ import (
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
@@ -41,11 +40,8 @@ func processPendingTask(task *types.Task) error {
if task.StoragePath == "" {
task.StoragePath = task.File.FileName
}
storageModel, err := dao.GetStorageByID(task.StorageID)
if err != nil {
return err
}
taskStorage, err := storage.GetStorageFromModel(*storageModel)
taskStorage, err := storage.GetStorageByName(task.StorageName)
if err != nil {
return err
}
@@ -142,7 +138,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%d]: %s", task.StorageID, task.StoragePath),
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
ID: task.ReplyMessageID,
})
case types.Failed:

View File

@@ -104,8 +104,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
task.FileName(),
// TODO: use storage name instead of ID
fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
@@ -115,7 +114,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath)),
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),

108
dao/db.go
View File

@@ -36,117 +36,11 @@ func Init() {
os.Exit(1)
}
logger.L.Debug("Database connected")
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}, &types.StorageModel{}); err != nil {
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}); err != nil {
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
}
for _, admin := range config.Cfg.Telegram.Admins {
CreateUser(int64(admin))
}
logger.L.Infof("Migrating config storages to users")
storageCfg := config.Cfg.Storage
allUsers, err := GetAllUsers()
if err != nil {
logger.L.Fatalf("Failed to get all users: %v", err)
} else {
for _, user := range allUsers {
found := false
for _, admin := range config.Cfg.Telegram.Admins {
if user.ChatID == int64(admin) {
found = true
break
}
}
if !found {
logger.L.Debugf("Deleting user %d", user.ChatID)
if err := DeleteUser(&user); err != nil {
logger.L.Fatalf("Failed to delete user %d: %v", user.ChatID, err)
}
}
}
}
// TODO: refactor this
for _, admin := range config.Cfg.Telegram.Admins {
user, err := GetUserByChatID(int64(admin))
if err != nil {
logger.L.Fatalf("Failed to get user by chat ID %d: %v", admin, err)
continue
}
if len(user.Storages) > 0 {
logger.L.Debugf("User %d already has storages", admin)
continue
}
if storageCfg.Alist.Enable {
alistStorage := &types.StorageModel{
Type: string(types.StorageTypeAlist),
Active: true,
Config: storageCfg.Alist.ToJSON(),
}
hash := alistStorage.GenHash()
alistStorage.Hash = hash
if storagedb, err := GetStorageByHash(hash); err == nil {
logger.L.Debugf("Alist storage already exists")
user.Storages = append(user.Storages, storagedb)
} else {
id, err := CreateStorage(alistStorage)
if err != nil {
logger.L.Fatalf("Failed to create storage: %v", err)
} else {
storagedb := &types.StorageModel{}
storagedb.ID = id
user.Storages = append(user.Storages, storagedb)
}
}
}
if storageCfg.Local.Enable {
localStorage := &types.StorageModel{
Type: string(types.StorageTypeLocal),
Active: true,
Config: storageCfg.Local.ToJSON(),
}
hash := localStorage.GenHash()
localStorage.Hash = hash
if storagedb, err := GetStorageByHash(hash); err == nil {
logger.L.Debugf("Local storage already exists")
user.Storages = append(user.Storages, storagedb)
} else {
id, err := CreateStorage(localStorage)
if err != nil {
logger.L.Fatalf("Failed to create storage: %v", err)
} else {
storagedb := &types.StorageModel{}
storagedb.ID = id
user.Storages = append(user.Storages, storagedb)
}
}
}
if storageCfg.Webdav.Enable {
webdavStorage := &types.StorageModel{
Type: string(types.StorageTypeWebdav),
Active: true,
Config: storageCfg.Webdav.ToJSON(),
}
hash := webdavStorage.GenHash()
webdavStorage.Hash = hash
if storagedb, err := GetStorageByHash(hash); err == nil {
logger.L.Debugf("Webdav storage already exists")
user.Storages = append(user.Storages, storagedb)
} else {
id, err := CreateStorage(webdavStorage)
if err != nil {
logger.L.Fatalf("Failed to create storage: %v", err)
} else {
storagedb := &types.StorageModel{}
storagedb.ID = id
user.Storages = append(user.Storages, storagedb)
}
}
}
if err := UpdateUser(user); err != nil {
logger.L.Fatalf("Failed to update user with storages: %v", err)
}
}
logger.L.Infof("Migration done")
}

View File

@@ -1,47 +0,0 @@
package dao
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
func GetActiveStorages() ([]types.StorageModel, error) {
var storageModels []types.StorageModel
err := db.Where("active = ?", true).Find(&storageModels).Error
return storageModels, err
}
func GetStorageByHash(hash string) (*types.StorageModel, error) {
var storageModel types.StorageModel
err := db.Where("hash = ?", hash).First(&storageModel).Error
return &storageModel, err
}
func GetStorageByID(id uint) (*types.StorageModel, error) {
var storageModel types.StorageModel
err := db.Preload("Users").First(&storageModel, id).Error
return &storageModel, err
}
func CreateStorage(model *types.StorageModel) (uint, error) {
if model.Hash == "" {
model.Hash = model.GenHash()
}
getModel, err := GetStorageByHash(model.Hash)
if err == nil {
return getModel.ID, nil
}
tx := db.Create(model)
if tx.Error != nil {
return 0, tx.Error
}
if model.Name == "" {
model.Name = fmt.Sprintf("%s - %d", model.Type, model.ID)
tx = db.Save(model)
if tx.Error != nil {
return 0, tx.Error
}
}
return model.ID, nil
}

View File

@@ -17,18 +17,9 @@ func GetAllUsers() ([]types.User, error) {
return users, err
}
// GetUserByUserID gets a user by their telegram user ID
//
// Return with active storages
func GetUserByChatID(chatID int64) (*types.User, error) {
var user types.User
err := db.Preload("Storages", "active = ?", true).Where("chat_id = ?", chatID).First(&user).Error
return &user, err
}
func GetUserWithAllStoragesByChatID(chatID int64) (*types.User, error) {
var user types.User
err := db.Preload("Storages").Where("chat_id = ?", chatID).First(&user).Error
err := db.Where("chat_id = ?", chatID).First(&user).Error
return &user, err
}
@@ -37,5 +28,5 @@ func UpdateUser(user *types.User) error {
}
func DeleteUser(user *types.User) error {
return db.Select("Storages").Delete(user).Error
return db.Delete(user).Error
}

View File

@@ -21,7 +21,7 @@ type Alist struct {
token string
baseURL string
loginInfo *loginRequest
config config.AlistConfig
config config.AlistStorageConfig
}
var ConfigurableItems = []string{
@@ -33,12 +33,16 @@ var ConfigurableItems = []string{
"token",
}
func (a *Alist) Init(model types.StorageModel) error {
var alistConfig config.AlistConfig
if err := json.Unmarshal([]byte(model.Config), &alistConfig); err != nil {
return fmt.Errorf("failed to unmarshal alist config: %w", err)
func (a *Alist) Init(cfg config.StorageConfig) error {
alistConfig, ok := cfg.(*config.AlistStorageConfig)
if !ok {
return fmt.Errorf("failed to cast alist config")
}
a.config = alistConfig
if err := alistConfig.Validate(); err != nil {
return err
}
a.config = *alistConfig
a.baseURL = alistConfig.URL
a.client = getHttpClient()
if alistConfig.Token != "" {
@@ -90,7 +94,7 @@ func (a *Alist) Init(model types.StorageModel) error {
}
logger.L.Debug("Logged in to Alist")
go a.refreshToken(alistConfig)
go a.refreshToken(*alistConfig)
return nil
}
@@ -98,6 +102,10 @@ func (a *Alist) Type() types.StorageType {
return types.StorageTypeAlist
}
func (a *Alist) Name() string {
return a.config.Name
}
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
file, err := os.Open(filePath)
if err != nil {

View File

@@ -48,7 +48,7 @@ func (a *Alist) getToken() error {
return nil
}
func (a *Alist) refreshToken(cfg config.AlistConfig) {
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
tokenExp := cfg.TokenExp
if tokenExp <= 0 {
logger.L.Warn("Invalid token expiration time, using default value")

View File

@@ -2,7 +2,6 @@ package local
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
@@ -13,19 +12,22 @@ import (
)
type Local struct {
config config.LocalConfig
config config.LocalStorageConfig
}
var ConfigurableItems = []string{
"base_path",
}
func (l *Local) Init(model types.StorageModel) error {
var localConfig config.LocalConfig
if err := json.Unmarshal([]byte(model.Config), &localConfig); err != nil {
return fmt.Errorf("failed to unmarshal local config: %w", err)
func (l *Local) Init(cfg config.StorageConfig) error {
localConfig, ok := cfg.(*config.LocalStorageConfig)
if !ok {
return fmt.Errorf("failed to cast local config")
}
l.config = localConfig
if err := localConfig.Validate(); err != nil {
return err
}
l.config = *localConfig
err := os.MkdirAll(localConfig.BasePath, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create local storage directory: %w", err)
@@ -37,6 +39,10 @@ func (l *Local) Type() types.StorageType {
return types.StorageTypeLocal
}
func (l *Local) Name() string {
return l.config.Name
}
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
absPath, err := filepath.Abs(storagePath)
if err != nil {

View File

@@ -2,9 +2,9 @@ package storage
import (
"context"
"errors"
"fmt"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/webdav"
@@ -12,34 +12,50 @@ import (
)
type Storage interface {
Init(model types.StorageModel) error
Init(cfg config.StorageConfig) error
Type() types.StorageType
Name() string
JoinStoragePath(task types.Task) string
Save(cttx context.Context, localFilePath, storagePath string) error
}
var (
ErrInvalidStorageID = errors.New("invalid storage ID")
)
var Storages = make(map[string]Storage)
var Storages = make(map[uint]Storage)
// Get storage from model, if it exists, otherwise create and init a new storage
func GetStorageFromModel(model types.StorageModel) (Storage, error) {
if model.ID == 0 {
return nil, ErrInvalidStorageID
// GetStorageByName returns storage by name from cache or creates new one
func GetStorageByName(name string) (Storage, error) {
if name == "" {
return nil, fmt.Errorf("storage name is required")
}
if storage, ok := Storages[model.ID]; ok {
storage, ok := Storages[name]
if ok {
return storage, nil
}
storage, err := NewStorage(model)
cfg := config.Cfg.GetStorageByName(name)
if cfg == nil {
return nil, fmt.Errorf("storage %s not found", name)
}
storage, err := NewStorage(cfg)
if err != nil {
return nil, err
}
Storages[model.ID] = storage
Storages[name] = storage
return storage, nil
}
func GetUserStorages(chatID int64) []Storage {
var storages []Storage
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
storage, err := GetStorageByName(name)
if err != nil {
continue
}
storages = append(storages, storage)
}
return storages
}
type StorageConstructor func() Storage
var storageConstructors = map[string]StorageConstructor{
@@ -48,15 +64,15 @@ var storageConstructors = map[string]StorageConstructor{
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
}
func NewStorage(model types.StorageModel) (Storage, error) {
constructor, ok := storageConstructors[model.Type]
func NewStorage(cfg config.StorageConfig) (Storage, error) {
constructor, ok := storageConstructors[string(cfg.GetType())]
if !ok {
return nil, fmt.Errorf("unsupported storage type: %s", model.Type)
return nil, fmt.Errorf("unsupported storage type: %s", cfg.GetType())
}
storage := constructor()
if err := storage.Init(model); err != nil {
return nil, fmt.Errorf("failed to init %s storage: %w", model.Type, err)
if err := storage.Init(cfg); err != nil {
return nil, fmt.Errorf("failed to init %s storage: %w", cfg.GetName(), err)
}
return storage, nil

View File

@@ -2,7 +2,6 @@ package webdav
import (
"context"
"encoding/json"
"fmt"
"os"
"path"
@@ -15,18 +14,21 @@ import (
)
type Webdav struct {
config config.WebdavConfig
config config.WebdavStorageConfig
client *gowebdav.Client
}
var ConfigurableItems = []string{"url", "username", "password", "base_path"}
func (w *Webdav) Init(model types.StorageModel) error {
var webdavConfig config.WebdavConfig
if err := json.Unmarshal([]byte(model.Config), &webdavConfig); err != nil {
return fmt.Errorf("failed to unmarshal webdav config: %w", err)
func (w *Webdav) Init(cfg config.StorageConfig) error {
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
if !ok {
return fmt.Errorf("failed to cast webdav config")
}
w.config = webdavConfig
if err := webdavConfig.Validate(); err != nil {
return err
}
w.config = *webdavConfig
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := client.Connect(); err != nil {
return fmt.Errorf("failed to connect to webdav server: %w", err)
@@ -40,6 +42,10 @@ func (w *Webdav) Type() types.StorageType {
return types.StorageTypeWebdav
}
func (w *Webdav) Name() string {
return w.config.Name
}
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)

View File

@@ -1,10 +1,6 @@
package types
import (
"crypto/md5"
"encoding/hex"
"gorm.io/datatypes"
"gorm.io/gorm"
)
@@ -22,33 +18,7 @@ type ReceivedFile struct {
type User struct {
gorm.Model
ChatID int64 `gorm:"uniqueIndex;not null"`
Silent bool
DefaultStorageID uint
Storages []*StorageModel `gorm:"many2many:user_storages;"`
}
type StorageModel struct {
gorm.Model
Type string
Config datatypes.JSON
Active bool
Users []*User `gorm:"many2many:user_storages;"`
Hash string `gorm:"uniqueIndex"`
// just for display
Name string `gorm:"not null"`
Desc string
}
func (s *StorageModel) GenHash() string {
if s.Type == "" || s.Config == nil {
return ""
}
typeBytes := []byte(s.Type)
configBytes := s.Config
structBytes := append(typeBytes, configBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
ChatID int64 `gorm:"uniqueIndex;not null"`
Silent bool
DefaultStorage string // Default storage name
}

View File

@@ -39,7 +39,7 @@ type Task struct {
Error error
Status TaskStatus
File *File
StorageID uint
StorageName string
StoragePath string
StartTime time.Time