From dfde65c28e99fcd9e765d2476bb191053c29d2da Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Tue, 18 Feb 2025 19:45:06 +0800 Subject: [PATCH] feat: (WIP) migrate storage configuration to user-specific models and remove deprecated storage loading --- bootstrap/init.go | 2 - config/viper.go | 26 ++++++++-- dao/db.go | 110 ++++++++++++++++++++++++++++++++++++++++- dao/storage.go | 29 +++++++++-- dao/user.go | 10 ++++ storage/alist/token.go | 7 ++- storage/storage.go | 19 ------- types/model.go | 26 ++++++++-- 8 files changed, 193 insertions(+), 36 deletions(-) diff --git a/bootstrap/init.go b/bootstrap/init.go index 35c65a2..8181806 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -6,7 +6,6 @@ import ( "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/logger" - "github.com/krau/SaveAny-Bot/storage" ) func InitAll() { @@ -16,6 +15,5 @@ func InitAll() { common.Init() dao.Init() - storage.LoadExistingStorages() bot.Init() } diff --git a/config/viper.go b/config/viper.go index 7706629..bd8a0de 100644 --- a/config/viper.go +++ b/config/viper.go @@ -3,9 +3,11 @@ package config import ( "fmt" "os" + "strconv" "strings" "github.com/spf13/viper" + "gorm.io/datatypes" ) type Config struct { @@ -53,6 +55,7 @@ type proxyConfig struct { /* 在配置文件中定义的存储将会为telegram.admins中的每个用户创建一个存储模型 */ +// these config will be removed in the future. type storageConfig struct { Alist AlistConfig `toml:"alist" mapstructure:"alist"` Local LocalConfig `toml:"local" mapstructure:"local"` @@ -65,8 +68,13 @@ type AlistConfig struct { Username string `toml:"username" mapstructure:"username"` Password string `toml:"password" mapstructure:"password"` Token string `toml:"token" mapstructure:"token"` - BasePath string `toml:"base_path" mapstructure:"base_path"` - TokenExp int64 `toml:"token_exp" mapstructure:"token_exp"` + BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"` + TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"` +} + +func (a *AlistConfig) ToJSON() datatypes.JSON { + tokenExp := strconv.FormatInt(a.TokenExp, 10) + return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`)) } type LocalConfig struct { @@ -74,6 +82,10 @@ type LocalConfig struct { BasePath string `toml:"base_path" mapstructure:"base_path"` } +func (l *LocalConfig) ToJSON() datatypes.JSON { + return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`)) +} + type WebdavConfig struct { Enable bool `toml:"enable" mapstructure:"enable"` URL string `toml:"url" mapstructure:"url"` @@ -82,6 +94,10 @@ type WebdavConfig struct { BasePath string `toml:"base_path" mapstructure:"base_path"` } +func (w *WebdavConfig) ToJSON() datatypes.JSON { + return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`)) +} + var Cfg *Config func Init() { @@ -109,9 +125,6 @@ func Init() { viper.SetDefault("db.path", "data/saveany.db") - viper.SetDefault("storage.alist.base_path", "/") - viper.SetDefault("storage.alist.token_exp", 3600) - viper.SafeWriteConfigAs("config.toml") if err := viper.ReadInConfig(); err != nil { @@ -124,6 +137,9 @@ func Init() { fmt.Println("Error unmarshalling config file, ", err) os.Exit(1) } + if Cfg.Storage != (storageConfig{}) { + fmt.Println("警告: 存储配置已经废弃, 未来版本将会移除.\n请直接使用 Bot 命令添加存储.") + } if Cfg.Workers < 1 || Cfg.Retry < 1 { fmt.Println("Invalid workers or retry value") os.Exit(1) diff --git a/dao/db.go b/dao/db.go index 392a9c5..f9c58fa 100644 --- a/dao/db.go +++ b/dao/db.go @@ -36,9 +36,117 @@ func Init() { os.Exit(1) } logger.L.Debug("Database connected") - db.AutoMigrate(&types.ReceivedFile{}, &types.User{}) + if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}, &types.StorageModel{}); err != nil { + logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err) + } for _, admin := range config.Cfg.Telegram.Admins { CreateUser(int64(admin)) } + + logger.L.Infof("Migrating config storages to users") + storageCfg := config.Cfg.Storage + + allUsers, err := GetAllUsers() + if err != nil { + logger.L.Fatalf("Failed to get all users: %v", err) + } else { + for _, user := range allUsers { + found := false + for _, admin := range config.Cfg.Telegram.Admins { + if user.ChatID == int64(admin) { + found = true + break + } + } + if !found { + logger.L.Debugf("Deleting user %d", user.ChatID) + if err := DeleteUser(&user); err != nil { + logger.L.Fatalf("Failed to delete user %d: %v", user.ChatID, err) + } + } + } + } + // TODO: refactor this + for _, admin := range config.Cfg.Telegram.Admins { + user, err := GetUserByChatID(int64(admin)) + if err != nil { + logger.L.Fatalf("Failed to get user by chat ID %d: %v", admin, err) + continue + } + if len(user.Storages) > 0 { + logger.L.Debugf("User %d already has storages", admin) + continue + } + if storageCfg.Alist.Enable { + alistStorage := &types.StorageModel{ + Type: string(types.StorageTypeAlist), + Active: true, + Config: storageCfg.Alist.ToJSON(), + } + hash := alistStorage.GenHash() + alistStorage.Hash = hash + if storagedb, err := GetStorageByHash(hash); err == nil { + logger.L.Debugf("Alist storage already exists") + user.Storages = append(user.Storages, storagedb) + } else { + id, err := CreateStorage(alistStorage) + if err != nil { + logger.L.Fatalf("Failed to create storage: %v", err) + } else { + storagedb := &types.StorageModel{} + storagedb.ID = id + user.Storages = append(user.Storages, storagedb) + } + } + } + if storageCfg.Local.Enable { + localStorage := &types.StorageModel{ + Type: string(types.StorageTypeLocal), + Active: true, + Config: storageCfg.Local.ToJSON(), + } + hash := localStorage.GenHash() + localStorage.Hash = hash + if storagedb, err := GetStorageByHash(hash); err == nil { + logger.L.Debugf("Local storage already exists") + user.Storages = append(user.Storages, storagedb) + } else { + id, err := CreateStorage(localStorage) + if err != nil { + logger.L.Fatalf("Failed to create storage: %v", err) + } else { + storagedb := &types.StorageModel{} + storagedb.ID = id + user.Storages = append(user.Storages, storagedb) + } + } + } + if storageCfg.Webdav.Enable { + webdavStorage := &types.StorageModel{ + Type: string(types.StorageTypeWebdav), + Active: true, + Config: storageCfg.Webdav.ToJSON(), + } + hash := webdavStorage.GenHash() + webdavStorage.Hash = hash + if storagedb, err := GetStorageByHash(hash); err == nil { + logger.L.Debugf("Webdav storage already exists") + user.Storages = append(user.Storages, storagedb) + } else { + id, err := CreateStorage(webdavStorage) + if err != nil { + logger.L.Fatalf("Failed to create storage: %v", err) + } else { + storagedb := &types.StorageModel{} + storagedb.ID = id + user.Storages = append(user.Storages, storagedb) + } + } + } + if err := UpdateUser(user); err != nil { + logger.L.Fatalf("Failed to update user with storages: %v", err) + } + } + logger.L.Infof("Migration done") } diff --git a/dao/storage.go b/dao/storage.go index a202bfd..1728050 100644 --- a/dao/storage.go +++ b/dao/storage.go @@ -12,15 +12,36 @@ func GetActiveStorages() ([]types.StorageModel, error) { return storageModels, err } +func GetStorageByHash(hash string) (*types.StorageModel, error) { + var storageModel types.StorageModel + err := db.Where("hash = ?", hash).First(&storageModel).Error + return &storageModel, err +} + func GetStorageByID(id uint) (*types.StorageModel, error) { var storageModel types.StorageModel err := db.Preload("Users").First(&storageModel, id).Error return &storageModel, err } -func CreateStorage(model *types.StorageModel) error { - if model.Name == "" { - model.Name = fmt.Sprintf("%s_%d", model.Type, model.ID) +func CreateStorage(model *types.StorageModel) (uint, error) { + if model.Hash == "" { + model.Hash = model.GenHash() } - return db.Create(model).Error + getModel, err := GetStorageByHash(model.Hash) + if err == nil { + return getModel.ID, nil + } + tx := db.Create(model) + if tx.Error != nil { + return 0, tx.Error + } + if model.Name == "" { + model.Name = fmt.Sprintf("%s - %d", model.Type, model.ID) + tx = db.Save(model) + if tx.Error != nil { + return 0, tx.Error + } + } + return model.ID, nil } diff --git a/dao/user.go b/dao/user.go index 8bc7487..5819101 100644 --- a/dao/user.go +++ b/dao/user.go @@ -11,6 +11,12 @@ func CreateUser(chatID int64) error { return db.Create(&types.User{ChatID: chatID}).Error } +func GetAllUsers() ([]types.User, error) { + var users []types.User + err := db.Find(&users).Error + return users, err +} + // GetUserByUserID gets a user by their telegram user ID // // Return with active storages @@ -29,3 +35,7 @@ func GetUserWithAllStoragesByChatID(chatID int64) (*types.User, error) { func UpdateUser(user *types.User) error { return db.Save(user).Error } + +func DeleteUser(user *types.User) error { + return db.Select("Storages").Delete(user).Error +} diff --git a/storage/alist/token.go b/storage/alist/token.go index ff27199..e4a6f01 100644 --- a/storage/alist/token.go +++ b/storage/alist/token.go @@ -49,8 +49,13 @@ func (a *Alist) getToken() error { } func (a *Alist) refreshToken(cfg config.AlistConfig) { + tokenExp := cfg.TokenExp + if tokenExp <= 0 { + logger.L.Warn("Invalid token expiration time, using default value") + tokenExp = 3600 + } for { - time.Sleep(time.Duration(cfg.TokenExp) * time.Second) + time.Sleep(time.Duration(tokenExp) * time.Second) if err := a.getToken(); err != nil { logger.L.Errorf("Failed to refresh jwt token: %v", err) continue diff --git a/storage/storage.go b/storage/storage.go index f114c5b..bf14be4 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -4,7 +4,6 @@ import ( "context" "errors" - "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/storage/alist" "github.com/krau/SaveAny-Bot/storage/local" "github.com/krau/SaveAny-Bot/storage/webdav" @@ -24,24 +23,6 @@ var ( var Storages = make(map[uint]Storage) -// LoadExistingStorages loads existing storages from the database, and initializes them -// -// Should only be called at startup -func LoadExistingStorages() error { - storageModels, err := dao.GetActiveStorages() - if err != nil { - return err - } - for _, storageModel := range storageModels { - storage, err := NewStorage(storageModel) - if err != nil { - return err - } - Storages[storageModel.ID] = storage - } - return nil -} - // Get storage from model, if it exists, otherwise create and init a new storage func GetStorageFromModel(model types.StorageModel) (Storage, error) { if model.ID == 0 { diff --git a/types/model.go b/types/model.go index 68206e1..5877f74 100644 --- a/types/model.go +++ b/types/model.go @@ -1,6 +1,9 @@ package types import ( + "crypto/md5" + "encoding/hex" + "gorm.io/datatypes" "gorm.io/gorm" ) @@ -19,7 +22,7 @@ type ReceivedFile struct { type User struct { gorm.Model - ChatID int64 `gorm:"uniqueIndex"` // Telegram user ID + ChatID int64 `gorm:"uniqueIndex;not null"` Silent bool DefaultStorageID uint Storages []*StorageModel `gorm:"many2many:user_storages;"` @@ -28,9 +31,24 @@ type User struct { type StorageModel struct { gorm.Model Type string - Name string // just for display - Desc string - Active bool Config datatypes.JSON + Active bool Users []*User `gorm:"many2many:user_storages;"` + Hash string `gorm:"uniqueIndex"` + // just for display + Name string `gorm:"not null"` + Desc string +} + +func (s *StorageModel) GenHash() string { + if s.Type == "" || s.Config == nil { + return "" + } + typeBytes := []byte(s.Type) + configBytes := s.Config + structBytes := append(typeBytes, configBytes...) + hash := md5.New() + hash.Write(structBytes) + hashBytes := hash.Sum(nil) + return hex.EncodeToString(hashBytes) }