From 968547b00548d089f9b985186bade8d93cf49662 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Tue, 18 Feb 2025 17:17:02 +0800 Subject: [PATCH] feat!: (WIP) decouple storage, users, and configuration files to support multiple users --- bootstrap/init.go | 4 +- bot/handle_link.go | 12 +++- bot/handlers.go | 132 +++++++++++------------------------- bot/utils.go | 84 +++++++++++------------ config/viper.go | 27 +++++--- core/core.go | 23 ++++--- core/utils.go | 13 ++-- dao/db.go | 2 +- dao/storage.go | 26 +++++++ dao/user.go | 19 ++++-- go.mod | 4 ++ go.sum | 10 +++ storage/alist/alist.go | 142 +++++++++------------------------------ storage/alist/token.go | 60 +++++++++++++++++ storage/alist/types.go | 44 ++++++++++++ storage/alist/utils.go | 23 +++++++ storage/local/local.go | 31 +++++++-- storage/storage.go | 123 +++++++++++++++++---------------- storage/webdav/webdav.go | 29 ++++++-- types/model.go | 26 +++++-- types/types.go | 12 ++-- 21 files changed, 474 insertions(+), 372 deletions(-) create mode 100644 dao/storage.go create mode 100644 storage/alist/token.go create mode 100644 storage/alist/types.go create mode 100644 storage/alist/utils.go diff --git a/bootstrap/init.go b/bootstrap/init.go index 0e01d6f..35c65a2 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -12,10 +12,10 @@ import ( func InitAll() { config.Init() logger.InitLogger() - logger.L.Info("Running...") + logger.L.Info("Starting SaveAny-Bot...") common.Init() - storage.Init() dao.Init() + storage.LoadExistingStorages() bot.Init() } diff --git a/bot/handle_link.go b/bot/handle_link.go index 6465745..c6576e6 100644 --- a/bot/handle_link.go +++ b/bot/handle_link.go @@ -47,9 +47,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { ctx.Reply(update, ext.ReplyTextString("Cannot find chat"), nil) return dispatcher.EndGroups } - user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) if err != nil { logger.L.Errorf("Failed to get user: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + if len(user.Storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) return dispatcher.EndGroups } replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil) @@ -64,6 +69,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil) return dispatcher.EndGroups } + // TODO: Better file name if file.FileName == "" { logger.L.Warnf("Empty file name, use generated name") file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash()) @@ -85,14 +91,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { }) return dispatcher.EndGroups } - if !user.Silent { + if !user.Silent || user.DefaultStorageID == 0 { return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, Status: types.Pending, File: file, - Storage: types.StorageType(user.DefaultStorage), + StorageID: user.DefaultStorageID, FileChatID: linkChat.GetID(), FileMessageID: messageID, ReplyMessageID: replied.ID, diff --git a/bot/handlers.go b/bot/handlers.go index f4e7714..91e318b 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/duke-git/lancet/v2/slice" - "github.com/gookit/goutil/maputil" "github.com/gotd/td/telegram/message/entity" "github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/tg" @@ -19,7 +18,6 @@ import ( "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" ) @@ -41,7 +39,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { } const noPermissionText string = ` -本 Bot 仅限个人使用. +您不在白名单中, 无法使用此 Bot. 您可以部署自己的实例: https://github.com/krau/SaveAny-Bot ` @@ -67,7 +65,7 @@ Save Any Bot - 转存你的 Telegram 文件 命令: /start - 开始使用 /help - 显示帮助 -/silent - 静默模式 +/silent - 开关静默模式 /storage - 设置默认存储位置 /save [自定义文件名] - 保存文件 /path <存储类型> <路径> - 更改文件保存路径 @@ -85,7 +83,7 @@ func help(ctx *ext.Context, update *ext.Update) error { } func silent(ctx *ext.Context, update *ext.Update) error { - user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) if err != nil { logger.L.Errorf("Failed to get user: %s", err) return dispatcher.EndGroups @@ -100,40 +98,17 @@ func silent(ctx *ext.Context, update *ext.Update) error { } func setDefaultStorage(ctx *ext.Context, update *ext.Update) error { - if len(storage.Storages) == 0 { - ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil) - return dispatcher.EndGroups - } - args := strings.Split(update.EffectiveMessage.Text, " ") - avaliableStorages := maputil.Keys(storage.Storages) - if len(args) < 2 { - text := []styling.StyledTextOption{ - styling.Plain("请提供存储位置名称, 可用项:"), - } - for _, name := range avaliableStorages { - text = append(text, styling.Plain("\n")) - text = append(text, styling.Code(name)) - } - text = append(text, styling.Plain("\n示例: /storage local")) - ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil) - return dispatcher.EndGroups - } - storageName := args[1] - if !slice.Contain(avaliableStorages, storageName) { - ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil) - return dispatcher.EndGroups - } - user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) if err != nil { - logger.L.Errorf("Failed to get user: %s", err) + logger.L.Errorf("Failed to get user active storages: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil) return dispatcher.EndGroups } - user.DefaultStorage = storageName - if err := dao.UpdateUser(user); err != nil { - logger.L.Errorf("Failed to update user: %s", err) + if len(user.Storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) return dispatcher.EndGroups } - ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已设置默认存储位置为 %s", storageName)), nil) + // TODO: select storage return dispatcher.EndGroups } @@ -154,6 +129,17 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) + if err != nil { + logger.L.Errorf("Failed to get user: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + if len(user.Storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) + return dispatcher.EndGroups + } + msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID) if err != nil { logger.L.Errorf("Failed to get message: %s", err) @@ -167,12 +153,6 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) - if err != nil { - logger.L.Errorf("Failed to get user: %s", err) - return dispatcher.EndGroups - } - replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil) if err != nil { logger.L.Errorf("Failed to reply: %s", err) @@ -191,6 +171,8 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { }) return dispatcher.EndGroups } + + // TODO: better file name if file.FileName == "" { file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash()) } @@ -213,14 +195,14 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { } return dispatcher.EndGroups } - if !user.Silent { + if !user.Silent || user.DefaultStorageID == 0 { return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, Status: types.Pending, File: file, - Storage: types.StorageType(user.DefaultStorage), + StorageID: user.DefaultStorageID, FileChatID: update.EffectiveChat().GetID(), ReplyMessageID: replied.ID, ReplyChatID: update.GetUserChat().GetID(), @@ -229,47 +211,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { } func setPath(ctx *ext.Context, update *ext.Update) error { - if len(storage.Storages) == 0 { - ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil) - return dispatcher.EndGroups - } - if update.EffectiveMessage == nil { - logger.L.Error("No effective message") - return dispatcher.EndGroups - } - args := strings.Split(update.EffectiveMessage.Text, " ") - if len(args) < 3 { - text := []styling.StyledTextOption{ - styling.Plain("请提供存储位置名称和路径, 可用项:"), - } - for name := range storage.Storages { - text = append(text, styling.Plain("\n")) - text = append(text, styling.Code(string(name))) - } - text = append(text, styling.Plain("\n示例: /path local /path/to/save")) - ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil) - return dispatcher.EndGroups - } - storageName := args[1] - if _, ok := storage.Storages[types.StorageType(storageName)]; !ok { - ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil) - return dispatcher.EndGroups - } - path := strings.Join(args[2:], " ") - switch storageName { - case "local": - config.Set("storage.local.base_path", path) - case "webdav": - config.Set("storage.webdav.base_path", path) - case "alist": - config.Set("storage.alist.base_path", path) - } - if err := config.ReloadConfig(); err != nil { - logger.L.Errorf("Failed to reload config: %s", err) - ctx.Reply(update, ext.ReplyTextString("设置失败: "+err.Error()), nil) - return dispatcher.EndGroups - } - ctx.Reply(update, ext.ReplyTextString("设置成功"), nil) + // TODO: implement return dispatcher.EndGroups } @@ -283,9 +225,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - user, err := dao.GetUserByUserID(update.GetUserChat().GetID()) + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) if err != nil { logger.L.Errorf("Failed to get user: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + if len(user.Storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) return dispatcher.EndGroups } @@ -323,14 +270,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - if !user.Silent { + if !user.Silent || user.DefaultStorageID == 0 { return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, Status: types.Pending, File: file, - Storage: types.StorageType(user.DefaultStorage), + StorageID: user.DefaultStorageID, FileChatID: update.EffectiveChat().GetID(), ReplyMessageID: msg.ID, ReplyChatID: update.GetUserChat().GetID(), @@ -349,11 +296,11 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } args := strings.Split(string(update.CallbackQuery.Data), " ") - chatID, _ := strconv.Atoi(args[1]) - messageID, _ := strconv.Atoi(args[2]) - storageName := args[3] - logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", chatID, messageID, storageName) - record, err := dao.GetReceivedFileByChatAndMessageID(int64(chatID), messageID) + fileChatID, _ := strconv.Atoi(args[1]) + fileMessageID, _ := strconv.Atoi(args[2]) + storageID, _ := strconv.Atoi(args[3]) + logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storageID: %d", fileChatID, fileMessageID, storageID) + record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID) if err != nil { logger.L.Errorf("Failed to get received file: %s", err) ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ @@ -370,7 +317,6 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { logger.L.Errorf("Failed to update received file: %s", err) } } - file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName) if err != nil { logger.L.Errorf("Failed to get file from message: %s", err) @@ -387,7 +333,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { Ctx: ctx, Status: types.Pending, File: file, - Storage: types.StorageType(storageName), + StorageID: uint(storageID), FileChatID: record.ChatID, ReplyMessageID: record.ReplyMessageID, FileMessageID: record.MessageID, diff --git a/bot/utils.go b/bot/utils.go index 39091a7..8e0ec29 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -11,9 +11,9 @@ import ( "github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/queue" - "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" ) @@ -22,6 +22,7 @@ var ( ErrEmptyPhoto = errors.New("photo is empty") ErrEmptyPhotoSize = errors.New("photo size is empty") ErrEmptyPhotoSizes = errors.New("photo size slice is empty") + ErrNoStorages = errors.New("no available storage") ) func supportedMediaFilter(m *tg.Message) (bool, error) { @@ -38,49 +39,28 @@ func supportedMediaFilter(m *tg.Message) (bool, error) { } } -var StorageDisplayNames = map[string]string{ - "all": "全部", - "local": "服务器磁盘", - "alist": "Alist", - "webdav": "WebDAV", -} - -func getAddTaskMarkup(chatID, messageID int) *tg.ReplyInlineMarkup { - storageButtons := make([]tg.KeyboardButtonClass, 0) - for _, name := range storage.StorageKeys { - storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{ - Text: StorageDisplayNames[string(name)], - Data: []byte(fmt.Sprintf("add %d %d %s", chatID, messageID, name)), +func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) { + user, err := dao.GetUserByChatID(userChatID) + if err != nil { + return nil, err + } + if len(user.Storages) < 1 { + return nil, ErrNoStorages + } + buttons := make([]tg.KeyboardButtonClass, 0) + for _, storage := range user.Storages { + buttons = append(buttons, &tg.KeyboardButtonCallback{ + Text: storage.Name, + Data: []byte(fmt.Sprintf("add %d %d %d", fileChatID, fileMessageID, storage.ID)), }) } - - if len(storageButtons) < 1 { - return nil - } - if len(storageButtons) == 1 { - return &tg.ReplyInlineMarkup{ - Rows: []tg.KeyboardButtonRow{ - { - Buttons: storageButtons, - }, - }, - } - } - return &tg.ReplyInlineMarkup{ - Rows: []tg.KeyboardButtonRow{ - { - Buttons: storageButtons, - }, - { - Buttons: []tg.KeyboardButtonClass{ - &tg.KeyboardButtonCallback{ - Text: "全部", - Data: []byte(fmt.Sprintf("add %d %d all", chatID, messageID)), - }, - }, - }, - }, + markup := &tg.ReplyInlineMarkup{} + for i := 0; i < len(buttons); i += 3 { + row := tg.KeyboardButtonRow{} + row.Buttons = buttons[i:min(i+3, len(buttons))] + markup.Rows = append(markup.Rows, row) } + return markup, nil } func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) { @@ -194,10 +174,26 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File } else { text, entities = entityBuilder.Complete() } - _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), chatID, fileMsgID) + if errors.Is(err, ErrNoStorages) { + logger.L.Errorf("Failed to get select storage markup: %s", err) + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无可用存储", + ID: toEditMsgID, + }) + return dispatcher.EndGroups + } else if err != nil { + logger.L.Errorf("Failed to get select storage markup: %s", err) + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无法获取存储", + ID: toEditMsgID, + }) + return dispatcher.EndGroups + } + _, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ Message: text, Entities: entities, - ReplyMarkup: getAddTaskMarkup(chatID, fileMsgID), + ReplyMarkup: markup, ID: toEditMsgID, }) if err != nil { @@ -207,7 +203,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File } func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error { - if user.DefaultStorage == "" { + if user.DefaultStorageID == 0 { ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ Message: "请先使用 /storage 设置默认存储位置", ID: task.ReplyMessageID, diff --git a/config/viper.go b/config/viper.go index 4d5036f..7706629 100644 --- a/config/viper.go +++ b/config/viper.go @@ -36,11 +36,12 @@ type dbConfig struct { } type telegramConfig struct { - Token string `toml:"token" mapstructure:"token"` - AppID int `toml:"app_id" mapstructure:"app_id"` - AppHash string `toml:"app_hash" mapstructure:"app_hash"` - Admins []int64 `toml:"admins" mapstructure:"admins"` - Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"` + Token string `toml:"token" mapstructure:"token"` + AppID int `toml:"app_id" mapstructure:"app_id"` + AppHash string `toml:"app_hash" mapstructure:"app_hash"` + // 白名单用户 + Admins []int64 `toml:"admins" mapstructure:"admins"` // Whitelisted users + Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"` } type proxyConfig struct { @@ -48,13 +49,17 @@ type proxyConfig struct { URL string `toml:"url" mapstructure:"url"` } +// pre-defined storages, for compatibility. +/* +在配置文件中定义的存储将会为telegram.admins中的每个用户创建一个存储模型 +*/ type storageConfig struct { - Alist alistConfig `toml:"alist" mapstructure:"alist"` - Local localConfig `toml:"local" mapstructure:"local"` - Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"` + Alist AlistConfig `toml:"alist" mapstructure:"alist"` + Local LocalConfig `toml:"local" mapstructure:"local"` + Webdav WebdavConfig `toml:"webdav" mapstructure:"webdav"` } -type alistConfig struct { +type AlistConfig struct { Enable bool `toml:"enable" mapstructure:"enable"` URL string `toml:"url" mapstructure:"url"` Username string `toml:"username" mapstructure:"username"` @@ -64,12 +69,12 @@ type alistConfig struct { TokenExp int64 `toml:"token_exp" mapstructure:"token_exp"` } -type localConfig struct { +type LocalConfig struct { Enable bool `toml:"enable" mapstructure:"enable"` BasePath string `toml:"base_path" mapstructure:"base_path"` } -type webdavConfig struct { +type WebdavConfig struct { Enable bool `toml:"enable" mapstructure:"enable"` URL string `toml:"url" mapstructure:"url"` Username string `toml:"username" mapstructure:"username"` diff --git a/core/core.go b/core/core.go index 597f92c..f5e56fb 100644 --- a/core/core.go +++ b/core/core.go @@ -17,8 +17,10 @@ import ( "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/bot" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/logger" "github.com/krau/SaveAny-Bot/queue" + "github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/types" ) @@ -39,17 +41,18 @@ func processPendingTask(task *types.Task) error { if task.StoragePath == "" { task.StoragePath = task.File.FileName } - switch task.Storage { - case types.Local: - task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath) - case types.Webdav: - task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath) - case types.Alist: - task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath) + storageModel, err := dao.GetStorageByID(task.StorageID) + if err != nil { + return err } + taskStorage, err := storage.GetStorageFromModel(*storageModel) + if err != nil { + return err + } + task.StoragePath = taskStorage.JoinStoragePath(*task) if task.File.FileSize == 0 { - return processPhoto(task, cacheDestPath) + return processPhoto(task, taskStorage, cacheDestPath) } ctx := task.Ctx.(*ext.Context) @@ -111,7 +114,7 @@ func processPendingTask(task *types.Task) error { ID: task.ReplyMessageID, }) - return saveFileWithRetry(task, cacheDestPath) + return saveFileWithRetry(task, taskStorage, cacheDestPath) } func worker(queue *queue.TaskQueue, semaphore chan struct{}) { @@ -139,7 +142,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) { case types.Succeeded: logger.L.Infof("Task succeeded: %s", task.String()) task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ - Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath), + Message: fmt.Sprintf("文件保存成功\n [%d]: %s", task.StorageID, task.StoragePath), ID: task.ReplyMessageID, }) case types.Failed: diff --git a/core/utils.go b/core/utils.go index 7a4368a..5bb2bf8 100644 --- a/core/utils.go +++ b/core/utils.go @@ -16,9 +16,9 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(task *types.Task, localFilePath string) error { +func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { - if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil { + if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } @@ -30,7 +30,7 @@ func saveFileWithRetry(task *types.Task, localFilePath string) error { return nil } -func processPhoto(task *types.Task, cachePath string) error { +func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath string) error { res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{ Location: task.File.Location, Offset: 0, @@ -53,7 +53,7 @@ func processPhoto(task *types.Task, cachePath string) error { logger.L.Infof("Downloaded file: %s", cachePath) - return saveFileWithRetry(task, cachePath) + return saveFileWithRetry(task, taskStorage, cachePath) } func getProgressBar(progress float64, totalCount int) string { @@ -104,7 +104,8 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i entityBuilder := entity.Builder{} text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%", task.FileName(), - fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath), + // TODO: use storage name instead of ID + fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath), getSpeed(bytesRead, startTime), getProgressBar(progress, barTotalCount), progress, @@ -114,7 +115,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i styling.Plain("正在处理下载任务\n文件名: "), styling.Code(task.FileName()), styling.Plain("\n保存路径: "), - styling.Code(fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath)), + styling.Code(fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath)), styling.Plain("\n平均速度: "), styling.Bold(getSpeed(bytesRead, task.StartTime)), styling.Plain("\n当前进度:\n "), diff --git a/dao/db.go b/dao/db.go index 83903cf..392a9c5 100644 --- a/dao/db.go +++ b/dao/db.go @@ -16,7 +16,7 @@ import ( var db *gorm.DB func Init() { - if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 755); err != nil { + if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil { logger.L.Fatal("Failed to create data directory: ", err) os.Exit(1) } diff --git a/dao/storage.go b/dao/storage.go new file mode 100644 index 0000000..a202bfd --- /dev/null +++ b/dao/storage.go @@ -0,0 +1,26 @@ +package dao + +import ( + "fmt" + + "github.com/krau/SaveAny-Bot/types" +) + +func GetActiveStorages() ([]types.StorageModel, error) { + var storageModels []types.StorageModel + err := db.Where("active = ?", true).Find(&storageModels).Error + return storageModels, err +} + +func GetStorageByID(id uint) (*types.StorageModel, error) { + var storageModel types.StorageModel + err := db.Preload("Users").First(&storageModel, id).Error + return &storageModel, err +} + +func CreateStorage(model *types.StorageModel) error { + if model.Name == "" { + model.Name = fmt.Sprintf("%s_%d", model.Type, model.ID) + } + return db.Create(model).Error +} diff --git a/dao/user.go b/dao/user.go index 32d49aa..8bc7487 100644 --- a/dao/user.go +++ b/dao/user.go @@ -4,16 +4,25 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func CreateUser(userID int64) error { - if _, err := GetUserByUserID(userID); err == nil { +func CreateUser(chatID int64) error { + if _, err := GetUserByChatID(chatID); err == nil { return nil } - return db.Create(&types.User{UserID: userID}).Error + return db.Create(&types.User{ChatID: chatID}).Error } -func GetUserByUserID(userID int64) (*types.User, error) { +// GetUserByUserID gets a user by their telegram user ID +// +// Return with active storages +func GetUserByChatID(chatID int64) (*types.User, error) { var user types.User - err := db.Where("user_id = ?", userID).First(&user).Error + err := db.Preload("Storages", "active = ?", true).Where("chat_id = ?", chatID).First(&user).Error + return &user, err +} + +func GetUserWithAllStoragesByChatID(chatID int64) (*types.User, error) { + var user types.User + err := db.Preload("Storages").Where("chat_id = ?", chatID).First(&user).Error return &user, err } diff --git a/go.mod b/go.mod index 390d7f0..df94418 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/AnimeKaizoku/cacher v1.0.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -31,6 +32,7 @@ require ( github.com/go-faster/jx v1.1.0 // indirect github.com/go-faster/xor v1.0.0 // indirect github.com/go-faster/yaml v0.4.6 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/google/go-github/v30 v30.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect @@ -60,6 +62,7 @@ require ( golang.org/x/oauth2 v0.26.0 // indirect golang.org/x/tools v0.30.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gorm.io/driver/mysql v1.5.6 // indirect modernc.org/libc v1.61.13 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.8.2 // indirect @@ -97,5 +100,6 @@ require ( golang.org/x/text v0.22.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/datatypes v1.2.5 gorm.io/gorm v1.25.12 ) diff --git a/go.sum b/go.sum index 7a36da0..5f0dae5 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPvSmGQ= github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= @@ -55,6 +57,9 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -265,6 +270,11 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I= +gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= diff --git a/storage/alist/alist.go b/storage/alist/alist.go index 1c461b2..1dc8d06 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -1,19 +1,19 @@ package alist import ( - "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" "net/url" "os" + "path" "time" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/types" ) type Alist struct { @@ -21,154 +21,72 @@ type Alist struct { token string baseURL string loginInfo *loginRequest + config config.AlistConfig } -var ( - ErrAlistLoginFailed = errors.New("failed to login to Alist") -) - -type loginRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -type loginResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Data struct { - Token string `json:"token"` - } `json:"data"` -} - -type meResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Data struct { - ID int `json:"id"` - Username string `json:"username"` - } `json:"data"` -} - -type putResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Data struct { - Task struct { - ID string `json:"id"` - Name string `json:"name"` - State int `json:"state"` - Status string `json:"status"` - Progress int `json:"progress"` - Error string `json:"error"` - } `json:"task"` - } `json:"data"` -} - -func (a *Alist) getToken() error { - loginBody, err := json.Marshal(a.loginInfo) - if err != nil { - return fmt.Errorf("failed to marshal login request: %w", err) +func (a *Alist) Init(model types.StorageModel) error { + var alistConfig config.AlistConfig + if err := json.Unmarshal([]byte(model.Config), &alistConfig); err != nil { + return fmt.Errorf("failed to unmarshal alist config: %w", err) } - - req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody)) - if err != nil { - return fmt.Errorf("failed to create login request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := a.client.Do(req) - if err != nil { - return fmt.Errorf("failed to send login request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read login response: %w", err) - } - - var loginResp loginResponse - if err := json.Unmarshal(body, &loginResp); err != nil { - return fmt.Errorf("failed to unmarshal login response: %w", err) - } - - if loginResp.Code != http.StatusOK { - return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message) - } - - a.token = loginResp.Data.Token - return nil -} - -func (a *Alist) refreshToken() { - for { - time.Sleep(time.Duration(config.Cfg.Storage.Alist.TokenExp) * time.Second) - if err := a.getToken(); err != nil { - logger.L.Errorf("Failed to refresh jwt token: %v", err) - continue - } - logger.L.Info("Refreshed Alist jwt token") - } -} - -func (a *Alist) Init() { - a.baseURL = config.Cfg.Storage.Alist.URL - a.client = &http.Client{ - Timeout: 12 * time.Hour, - Transport: &http.Transport{ - TLSHandshakeTimeout: 10 * time.Second, - }, - } - if config.Cfg.Storage.Alist.Token != "" { - a.token = config.Cfg.Storage.Alist.Token + a.config = alistConfig + a.baseURL = alistConfig.URL + a.client = getHttpClient() + if alistConfig.Token != "" { + a.token = alistConfig.Token ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil) if err != nil { logger.L.Fatalf("Failed to create request: %v", err) - os.Exit(1) + return err } req.Header.Set("Authorization", a.token) resp, err := a.client.Do(req) if err != nil { logger.L.Fatalf("Failed to send request: %v", err) - os.Exit(1) + return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.L.Fatalf("Failed to get alist user info: %s", resp.Status) - os.Exit(1) + return err } body, err := io.ReadAll(resp.Body) if err != nil { logger.L.Fatalf("Failed to read response body: %v", err) - os.Exit(1) + return err } var meResp meResponse if err := json.Unmarshal(body, &meResp); err != nil { logger.L.Fatalf("Failed to unmarshal me response: %v", err) - os.Exit(1) + return err } if meResp.Code != http.StatusOK { logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message) - os.Exit(1) + return err } logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username) - return + return nil } a.loginInfo = &loginRequest{ - Username: config.Cfg.Storage.Alist.Username, - Password: config.Cfg.Storage.Alist.Password, + Username: alistConfig.Username, + Password: alistConfig.Password, } if err := a.getToken(); err != nil { logger.L.Fatalf("Failed to login to Alist: %v", err) - os.Exit(1) + return err } logger.L.Debug("Logged in to Alist") - go a.refreshToken() + go a.refreshToken(alistConfig) + return nil +} + +func (a *Alist) Type() types.StorageType { + return types.StorageTypeAlist } func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { @@ -219,3 +137,7 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error { return nil } + +func (a *Alist) JoinStoragePath(task types.Task) string { + return path.Join(a.config.BasePath, task.StoragePath) +} diff --git a/storage/alist/token.go b/storage/alist/token.go new file mode 100644 index 0000000..ff27199 --- /dev/null +++ b/storage/alist/token.go @@ -0,0 +1,60 @@ +package alist + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/logger" +) + +func (a *Alist) getToken() error { + loginBody, err := json.Marshal(a.loginInfo) + if err != nil { + return fmt.Errorf("failed to marshal login request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody)) + if err != nil { + return fmt.Errorf("failed to create login request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send login request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read login response: %w", err) + } + + var loginResp loginResponse + if err := json.Unmarshal(body, &loginResp); err != nil { + return fmt.Errorf("failed to unmarshal login response: %w", err) + } + + if loginResp.Code != http.StatusOK { + return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message) + } + + a.token = loginResp.Data.Token + return nil +} + +func (a *Alist) refreshToken(cfg config.AlistConfig) { + for { + time.Sleep(time.Duration(cfg.TokenExp) * time.Second) + if err := a.getToken(); err != nil { + logger.L.Errorf("Failed to refresh jwt token: %v", err) + continue + } + logger.L.Info("Refreshed Alist jwt token") + } +} diff --git a/storage/alist/types.go b/storage/alist/types.go new file mode 100644 index 0000000..c0bec32 --- /dev/null +++ b/storage/alist/types.go @@ -0,0 +1,44 @@ +package alist + +import "errors" + +var ( + ErrAlistLoginFailed = errors.New("failed to login to Alist") +) + +type loginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type loginResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Token string `json:"token"` + } `json:"data"` +} + +type meResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + ID int `json:"id"` + Username string `json:"username"` + } `json:"data"` +} + +type putResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Task struct { + ID string `json:"id"` + Name string `json:"name"` + State int `json:"state"` + Status string `json:"status"` + Progress int `json:"progress"` + Error string `json:"error"` + } `json:"task"` + } `json:"data"` +} diff --git a/storage/alist/utils.go b/storage/alist/utils.go new file mode 100644 index 0000000..67f5d95 --- /dev/null +++ b/storage/alist/utils.go @@ -0,0 +1,23 @@ +package alist + +import ( + "net/http" + "time" +) + +var ( + httpClient *http.Client +) + +func getHttpClient() *http.Client { + if httpClient != nil { + return httpClient + } + httpClient = &http.Client{ + Timeout: 12 * time.Hour, + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + }, + } + return httpClient +} diff --git a/storage/local/local.go b/storage/local/local.go index 373f3fc..1424b9f 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -2,22 +2,35 @@ package local import ( "context" + "encoding/json" + "fmt" "os" "path/filepath" "github.com/duke-git/lancet/v2/fileutil" "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/types" ) -type Local struct{} +type Local struct { + config config.LocalConfig +} -func (l *Local) Init() { - err := os.MkdirAll(config.Cfg.Storage.Local.BasePath, os.ModePerm) - if err != nil { - logger.L.Fatalf("Failed to create local storage directory: %s", err) - os.Exit(1) +func (l *Local) Init(model types.StorageModel) error { + var localConfig config.LocalConfig + if err := json.Unmarshal([]byte(model.Config), &localConfig); err != nil { + return fmt.Errorf("failed to unmarshal local config: %w", err) } + l.config = localConfig + err := os.MkdirAll(localConfig.BasePath, os.ModePerm) + if err != nil { + return fmt.Errorf("failed to create local storage directory: %w", err) + } + return nil +} + +func (l *Local) Type() types.StorageType { + return types.StorageTypeLocal } func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { @@ -30,3 +43,7 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error { } return fileutil.CopyFile(filePath, storagePath) } + +func (l *Local) JoinStoragePath(task types.Task) string { + return filepath.Join(l.config.BasePath, task.StoragePath) +} diff --git a/storage/storage.go b/storage/storage.go index 1cfcb95..f114c5b 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,13 +3,8 @@ package storage import ( "context" "errors" - "path" - "path/filepath" - "sync" - "github.com/duke-git/lancet/v2/slice" - "github.com/krau/SaveAny-Bot/config" - "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/dao" "github.com/krau/SaveAny-Bot/storage/alist" "github.com/krau/SaveAny-Bot/storage/local" "github.com/krau/SaveAny-Bot/storage/webdav" @@ -17,68 +12,72 @@ import ( ) type Storage interface { - Init() + Init(model types.StorageModel) error + Type() types.StorageType + JoinStoragePath(task types.Task) string Save(cttx context.Context, localFilePath, storagePath string) error } -var Storages = make(map[types.StorageType]Storage) -var StorageKeys = make([]types.StorageType, 0) +var ( + ErrInvalidStorageID = errors.New("invalid storage ID") +) -func Init() { - logger.L.Debug("Initializing storage...") - if config.Cfg.Storage.Alist.Enable { - Storages[types.Alist] = new(alist.Alist) - Storages[types.Alist].Init() - } - if config.Cfg.Storage.Local.Enable { - Storages[types.Local] = new(local.Local) - Storages[types.Local].Init() - } - if config.Cfg.Storage.Webdav.Enable { - Storages[types.Webdav] = new(webdav.Webdav) - Storages[types.Webdav].Init() - } +var Storages = make(map[uint]Storage) - for k := range Storages { - StorageKeys = append(StorageKeys, k) +// LoadExistingStorages loads existing storages from the database, and initializes them +// +// Should only be called at startup +func LoadExistingStorages() error { + storageModels, err := dao.GetActiveStorages() + if err != nil { + return err } - - slice.Sort(StorageKeys) - - logger.L.Debug("Storage initialized") -} - -func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error { - logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath) - if ctx == nil { - ctx = context.Background() - } - if storageType != types.StorageAll { - return Storages[storageType].Save(ctx, filePath, storagePath) - } - errs := make([]error, 0) - var wg sync.WaitGroup - for _, storage := range Storages { - wg.Add(1) - go func(storage Storage) { - defer wg.Done() - storageDestPath := storagePath - switch storage.(type) { - case *local.Local: - storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath) - case *webdav.Webdav: - storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath) - case *alist.Alist: - storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath) - } - if err := storage.Save(ctx, filePath, storageDestPath); err != nil { - errs = append(errs, err) - } - }(storage) - } - wg.Wait() - if len(errs) > 0 { - return errors.Join(errs...) + for _, storageModel := range storageModels { + storage, err := NewStorage(storageModel) + if err != nil { + return err + } + Storages[storageModel.ID] = storage } return nil } + +// Get storage from model, if it exists, otherwise create and init a new storage +func GetStorageFromModel(model types.StorageModel) (Storage, error) { + if model.ID == 0 { + return nil, ErrInvalidStorageID + } + if storage, ok := Storages[model.ID]; ok { + return storage, nil + } + storage, err := NewStorage(model) + if err != nil { + return nil, err + } + Storages[model.ID] = storage + return storage, nil +} + +func NewStorage(storageModel types.StorageModel) (Storage, error) { + switch storageModel.Type { + case string(types.StorageTypeAlist): + alistStorage := new(alist.Alist) + if err := alistStorage.Init(storageModel); err != nil { + return nil, err + } + return alistStorage, nil + case string(types.StorageTypeLocal): + localStorage := new(local.Local) + if err := localStorage.Init(storageModel); err != nil { + return nil, err + } + return localStorage, nil + case string(types.StorageTypeWebdav): + webdavStorage := new(webdav.Webdav) + if err := webdavStorage.Init(storageModel); err != nil { + return nil, err + } + return webdavStorage, nil + } + return nil, nil +} diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index f0c3a8b..b0dcc99 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -2,29 +2,42 @@ package webdav import ( "context" + "encoding/json" + "fmt" "os" "path" "time" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/logger" + "github.com/krau/SaveAny-Bot/types" "github.com/studio-b12/gowebdav" ) -type Webdav struct{} +type Webdav struct { + config config.WebdavConfig +} var ( Client *gowebdav.Client ) -func (w *Webdav) Init() { - webdavConfig := config.Cfg.Storage.Webdav +func (w *Webdav) Init(model types.StorageModel) error { + var webdavConfig config.WebdavConfig + if err := json.Unmarshal([]byte(model.Config), &webdavConfig); err != nil { + return fmt.Errorf("failed to unmarshal webdav config: %w", err) + } + w.config = webdavConfig Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password) if err := Client.Connect(); err != nil { - logger.L.Fatalf("Failed to connect to webdav server: %v", err) - os.Exit(1) + return fmt.Errorf("failed to connect to webdav server: %w", err) } - Client.SetTimeout(24 * time.Hour) + Client.SetTimeout(12 * time.Hour) + return nil +} + +func (w *Webdav) Type() types.StorageType { + return types.StorageTypeWebdav } func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error { @@ -45,3 +58,7 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error { } return nil } + +func (w *Webdav) JoinStoragePath(task types.Task) string { + return path.Join(w.config.BasePath, task.StoragePath) +} diff --git a/types/model.go b/types/model.go index 3c493f8..68206e1 100644 --- a/types/model.go +++ b/types/model.go @@ -1,14 +1,17 @@ package types import ( + "gorm.io/datatypes" "gorm.io/gorm" ) type ReceivedFile struct { gorm.Model - Processing bool - ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` - MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` + Processing bool + // Which chat the file is from + ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` + // Which message the file is from + MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"` ReplyMessageID int ReplyChatID int64 FileName string @@ -16,7 +19,18 @@ type ReceivedFile struct { type User struct { gorm.Model - UserID int64 `gorm:"uniqueIndex"` - Silent bool - DefaultStorage string + ChatID int64 `gorm:"uniqueIndex"` // Telegram user ID + Silent bool + DefaultStorageID uint + Storages []*StorageModel `gorm:"many2many:user_storages;"` +} + +type StorageModel struct { + gorm.Model + Type string + Name string // just for display + Desc string + Active bool + Config datatypes.JSON + Users []*User `gorm:"many2many:user_storages;"` } diff --git a/types/types.go b/types/types.go index e7c522e..273259c 100644 --- a/types/types.go +++ b/types/types.go @@ -22,20 +22,20 @@ var ( type StorageType string var ( - StorageAll StorageType = "all" - Local StorageType = "local" - Webdav StorageType = "webdav" - Alist StorageType = "alist" + StorageAll StorageType = "all" + StorageTypeLocal StorageType = "local" + StorageTypeWebdav StorageType = "webdav" + StorageTypeAlist StorageType = "alist" ) -var StorageTypes = []StorageType{Local, Alist, Webdav, StorageAll} +var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageAll} type Task struct { Ctx context.Context Error error Status TaskStatus File *File - Storage StorageType + StorageID uint StoragePath string StartTime time.Time