feat!: (WIP) decouple storage, users, and configuration files to support multiple users

This commit is contained in:
krau
2025-02-18 17:17:02 +08:00
parent 9367419156
commit 968547b005
21 changed files with 474 additions and 372 deletions

View File

@@ -12,10 +12,10 @@ import (
func InitAll() {
config.Init()
logger.InitLogger()
logger.L.Info("Running...")
logger.L.Info("Starting SaveAny-Bot...")
common.Init()
storage.Init()
dao.Init()
storage.LoadExistingStorages()
bot.Init()
}

View File

@@ -47,9 +47,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("Cannot find chat"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
@@ -64,6 +69,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
// TODO: Better file name
if file.FileName == "" {
logger.L.Warnf("Empty file name, use generated name")
file.FileName = fmt.Sprintf("%d_%d_%s", linkChat.GetID(), messageID, file.Hash())
@@ -85,14 +91,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
})
return dispatcher.EndGroups
}
if !user.Silent {
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
StorageID: user.DefaultStorageID,
FileChatID: linkChat.GetID(),
FileMessageID: messageID,
ReplyMessageID: replied.ID,

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/duke-git/lancet/v2/slice"
"github.com/gookit/goutil/maputil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
@@ -19,7 +18,6 @@ import (
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -41,7 +39,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
}
const noPermissionText string = `
本 Bot 仅限个人使用.
您不在白名单中, 无法使用此 Bot.
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
`
@@ -67,7 +65,7 @@ Save Any Bot - 转存你的 Telegram 文件
命令:
/start - 开始使用
/help - 显示帮助
/silent - 静默模式
/silent - 开关静默模式
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/path <存储类型> <路径> - 更改文件保存路径
@@ -85,7 +83,7 @@ func help(ctx *ext.Context, update *ext.Update) error {
}
func silent(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
@@ -100,40 +98,17 @@ func silent(ctx *ext.Context, update *ext.Update) error {
}
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
return dispatcher.EndGroups
}
args := strings.Split(update.EffectiveMessage.Text, " ")
avaliableStorages := maputil.Keys(storage.Storages)
if len(args) < 2 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称, 可用项:"),
}
for _, name := range avaliableStorages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(name))
}
text = append(text, styling.Plain("\n示例: /storage local"))
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
return dispatcher.EndGroups
}
storageName := args[1]
if !slice.Contain(avaliableStorages, storageName) {
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
logger.L.Errorf("Failed to get user active storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil)
return dispatcher.EndGroups
}
user.DefaultStorage = storageName
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
if len(user.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("已设置默认存储位置为 %s", storageName)), nil)
// TODO: select storage
return dispatcher.EndGroups
}
@@ -154,6 +129,17 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
if err != nil {
logger.L.Errorf("Failed to get message: %s", err)
@@ -167,12 +153,6 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件信息..."), nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
@@ -191,6 +171,8 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
})
return dispatcher.EndGroups
}
// TODO: better file name
if file.FileName == "" {
file.FileName = fmt.Sprintf("%d_%d_%s", update.EffectiveChat().GetID(), replyToMsgID, file.Hash())
}
@@ -213,14 +195,14 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
return dispatcher.EndGroups
}
if !user.Silent {
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
StorageID: user.DefaultStorageID,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
@@ -229,47 +211,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
func setPath(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("未配置存储"), nil)
return dispatcher.EndGroups
}
if update.EffectiveMessage == nil {
logger.L.Error("No effective message")
return dispatcher.EndGroups
}
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 3 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称和路径, 可用项:"),
}
for name := range storage.Storages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(string(name)))
}
text = append(text, styling.Plain("\n示例: /path local /path/to/save"))
ctx.Reply(update, ext.ReplyTextStyledTextArray(text), nil)
return dispatcher.EndGroups
}
storageName := args[1]
if _, ok := storage.Storages[types.StorageType(storageName)]; !ok {
ctx.Reply(update, ext.ReplyTextString("存储位置不存在"), nil)
return dispatcher.EndGroups
}
path := strings.Join(args[2:], " ")
switch storageName {
case "local":
config.Set("storage.local.base_path", path)
case "webdav":
config.Set("storage.webdav.base_path", path)
case "alist":
config.Set("storage.alist.base_path", path)
}
if err := config.ReloadConfig(); err != nil {
logger.L.Errorf("Failed to reload config: %s", err)
ctx.Reply(update, ext.ReplyTextString("设置失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("设置成功"), nil)
// TODO: implement
return dispatcher.EndGroups
}
@@ -283,9 +225,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
@@ -323,14 +270,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if !user.Silent {
if !user.Silent || user.DefaultStorageID == 0 {
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
StorageID: user.DefaultStorageID,
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
@@ -349,11 +296,11 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
args := strings.Split(string(update.CallbackQuery.Data), " ")
chatID, _ := strconv.Atoi(args[1])
messageID, _ := strconv.Atoi(args[2])
storageName := args[3]
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", chatID, messageID, storageName)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(chatID), messageID)
fileChatID, _ := strconv.Atoi(args[1])
fileMessageID, _ := strconv.Atoi(args[2])
storageID, _ := strconv.Atoi(args[3])
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storageID: %d", fileChatID, fileMessageID, storageID)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
@@ -370,7 +317,6 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
logger.L.Errorf("Failed to update received file: %s", err)
}
}
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
@@ -387,7 +333,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(storageName),
StorageID: uint(storageID),
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,

View File

@@ -11,9 +11,9 @@ import (
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -22,6 +22,7 @@ var (
ErrEmptyPhoto = errors.New("photo is empty")
ErrEmptyPhotoSize = errors.New("photo size is empty")
ErrEmptyPhotoSizes = errors.New("photo size slice is empty")
ErrNoStorages = errors.New("no available storage")
)
func supportedMediaFilter(m *tg.Message) (bool, error) {
@@ -38,49 +39,28 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
}
}
var StorageDisplayNames = map[string]string{
"all": "全部",
"local": "服务器磁盘",
"alist": "Alist",
"webdav": "WebDAV",
}
func getAddTaskMarkup(chatID, messageID int) *tg.ReplyInlineMarkup {
storageButtons := make([]tg.KeyboardButtonClass, 0)
for _, name := range storage.StorageKeys {
storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{
Text: StorageDisplayNames[string(name)],
Data: []byte(fmt.Sprintf("add %d %d %s", chatID, messageID, name)),
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
user, err := dao.GetUserByChatID(userChatID)
if err != nil {
return nil, err
}
if len(user.Storages) < 1 {
return nil, ErrNoStorages
}
buttons := make([]tg.KeyboardButtonClass, 0)
for _, storage := range user.Storages {
buttons = append(buttons, &tg.KeyboardButtonCallback{
Text: storage.Name,
Data: []byte(fmt.Sprintf("add %d %d %d", fileChatID, fileMessageID, storage.ID)),
})
}
if len(storageButtons) < 1 {
return nil
}
if len(storageButtons) == 1 {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
},
}
}
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButtonCallback{
Text: "全部",
Data: []byte(fmt.Sprintf("add %d %d all", chatID, messageID)),
},
},
},
},
markup := &tg.ReplyInlineMarkup{}
for i := 0; i < len(buttons); i += 3 {
row := tg.KeyboardButtonRow{}
row.Buttons = buttons[i:min(i+3, len(buttons))]
markup.Rows = append(markup.Rows, row)
}
return markup, nil
}
func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.File, error) {
@@ -194,10 +174,26 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
} else {
text, entities = entityBuilder.Complete()
}
_, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), chatID, fileMsgID)
if errors.Is(err, ErrNoStorages) {
logger.L.Errorf("Failed to get select storage markup: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无可用存储",
ID: toEditMsgID,
})
return dispatcher.EndGroups
} else if err != nil {
logger.L.Errorf("Failed to get select storage markup: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取存储",
ID: toEditMsgID,
})
return dispatcher.EndGroups
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(chatID, fileMsgID),
ReplyMarkup: markup,
ID: toEditMsgID,
})
if err != nil {
@@ -207,7 +203,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
}
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
if user.DefaultStorage == "" {
if user.DefaultStorageID == 0 {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
ID: task.ReplyMessageID,

View File

@@ -36,11 +36,12 @@ type dbConfig struct {
}
type telegramConfig struct {
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash"`
Admins []int64 `toml:"admins" mapstructure:"admins"`
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash"`
// 白名单用户
Admins []int64 `toml:"admins" mapstructure:"admins"` // Whitelisted users
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
}
type proxyConfig struct {
@@ -48,13 +49,17 @@ type proxyConfig struct {
URL string `toml:"url" mapstructure:"url"`
}
// pre-defined storages, for compatibility.
/*
在配置文件中定义的存储将会为telegram.admins中的每个用户创建一个存储模型
*/
type storageConfig struct {
Alist alistConfig `toml:"alist" mapstructure:"alist"`
Local localConfig `toml:"local" mapstructure:"local"`
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
Alist AlistConfig `toml:"alist" mapstructure:"alist"`
Local LocalConfig `toml:"local" mapstructure:"local"`
Webdav WebdavConfig `toml:"webdav" mapstructure:"webdav"`
}
type alistConfig struct {
type AlistConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
Username string `toml:"username" mapstructure:"username"`
@@ -64,12 +69,12 @@ type alistConfig struct {
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp"`
}
type localConfig struct {
type LocalConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
BasePath string `toml:"base_path" mapstructure:"base_path"`
}
type webdavConfig struct {
type WebdavConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
Username string `toml:"username" mapstructure:"username"`

View File

@@ -17,8 +17,10 @@ import (
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -39,17 +41,18 @@ func processPendingTask(task *types.Task) error {
if task.StoragePath == "" {
task.StoragePath = task.File.FileName
}
switch task.Storage {
case types.Local:
task.StoragePath = filepath.Join(config.Cfg.Storage.Local.BasePath, task.StoragePath)
case types.Webdav:
task.StoragePath = path.Join(config.Cfg.Storage.Webdav.BasePath, task.StoragePath)
case types.Alist:
task.StoragePath = path.Join(config.Cfg.Storage.Alist.BasePath, task.StoragePath)
storageModel, err := dao.GetStorageByID(task.StorageID)
if err != nil {
return err
}
taskStorage, err := storage.GetStorageFromModel(*storageModel)
if err != nil {
return err
}
task.StoragePath = taskStorage.JoinStoragePath(*task)
if task.File.FileSize == 0 {
return processPhoto(task, cacheDestPath)
return processPhoto(task, taskStorage, cacheDestPath)
}
ctx := task.Ctx.(*ext.Context)
@@ -111,7 +114,7 @@ func processPendingTask(task *types.Task) error {
ID: task.ReplyMessageID,
})
return saveFileWithRetry(task, cacheDestPath)
return saveFileWithRetry(task, taskStorage, cacheDestPath)
}
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
@@ -139,7 +142,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath),
Message: fmt.Sprintf("文件保存成功\n [%d]: %s", task.StorageID, task.StoragePath),
ID: task.ReplyMessageID,
})
case types.Failed:

View File

@@ -16,9 +16,9 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(task *types.Task, localFilePath string) error {
func saveFileWithRetry(task *types.Task, taskStorage storage.Storage, localFilePath string) error {
for i := 0; i <= config.Cfg.Retry; i++ {
if err := storage.Save(task.Storage, task.Ctx, localFilePath, task.StoragePath); err != nil {
if err := taskStorage.Save(task.Ctx, localFilePath, task.StoragePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
@@ -30,7 +30,7 @@ func saveFileWithRetry(task *types.Task, localFilePath string) error {
return nil
}
func processPhoto(task *types.Task, cachePath string) error {
func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath string) error {
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
Location: task.File.Location,
Offset: 0,
@@ -53,7 +53,7 @@ func processPhoto(task *types.Task, cachePath string) error {
logger.L.Infof("Downloaded file: %s", cachePath)
return saveFileWithRetry(task, cachePath)
return saveFileWithRetry(task, taskStorage, cachePath)
}
func getProgressBar(progress float64, totalCount int) string {
@@ -104,7 +104,8 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath),
// TODO: use storage name instead of ID
fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
@@ -114,7 +115,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage, task.StoragePath)),
styling.Code(fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),

View File

@@ -16,7 +16,7 @@ import (
var db *gorm.DB
func Init() {
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 755); err != nil {
if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil {
logger.L.Fatal("Failed to create data directory: ", err)
os.Exit(1)
}

26
dao/storage.go Normal file
View File

@@ -0,0 +1,26 @@
package dao
import (
"fmt"
"github.com/krau/SaveAny-Bot/types"
)
func GetActiveStorages() ([]types.StorageModel, error) {
var storageModels []types.StorageModel
err := db.Where("active = ?", true).Find(&storageModels).Error
return storageModels, err
}
func GetStorageByID(id uint) (*types.StorageModel, error) {
var storageModel types.StorageModel
err := db.Preload("Users").First(&storageModel, id).Error
return &storageModel, err
}
func CreateStorage(model *types.StorageModel) error {
if model.Name == "" {
model.Name = fmt.Sprintf("%s_%d", model.Type, model.ID)
}
return db.Create(model).Error
}

View File

@@ -4,16 +4,25 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func CreateUser(userID int64) error {
if _, err := GetUserByUserID(userID); err == nil {
func CreateUser(chatID int64) error {
if _, err := GetUserByChatID(chatID); err == nil {
return nil
}
return db.Create(&types.User{UserID: userID}).Error
return db.Create(&types.User{ChatID: chatID}).Error
}
func GetUserByUserID(userID int64) (*types.User, error) {
// GetUserByUserID gets a user by their telegram user ID
//
// Return with active storages
func GetUserByChatID(chatID int64) (*types.User, error) {
var user types.User
err := db.Where("user_id = ?", userID).First(&user).Error
err := db.Preload("Storages", "active = ?", true).Where("chat_id = ?", chatID).First(&user).Error
return &user, err
}
func GetUserWithAllStoragesByChatID(chatID int64) (*types.User, error) {
var user types.User
err := db.Preload("Storages").Where("chat_id = ?", chatID).First(&user).Error
return &user, err
}

4
go.mod
View File

@@ -17,6 +17,7 @@ require (
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/AnimeKaizoku/cacher v1.0.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -31,6 +32,7 @@ require (
github.com/go-faster/jx v1.1.0 // indirect
github.com/go-faster/xor v1.0.0 // indirect
github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/google/go-github/v30 v30.1.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20250128161936-077ca0a936bf // indirect
@@ -60,6 +62,7 @@ require (
golang.org/x/oauth2 v0.26.0 // indirect
golang.org/x/tools v0.30.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gorm.io/driver/mysql v1.5.6 // indirect
modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.8.2 // indirect
@@ -97,5 +100,6 @@ require (
golang.org/x/text v0.22.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/datatypes v1.2.5
gorm.io/gorm v1.25.12
)

10
go.sum
View File

@@ -1,3 +1,5 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AnimeKaizoku/cacher v1.0.2 h1:7Bf5qRylWb7q2Evib0OXlhG37/t7BP2HK/7IyPvSmGQ=
github.com/AnimeKaizoku/cacher v1.0.2/go.mod h1:jw0de/b0K6W7Y3T9rHCMGVKUf6oG7hENNcssxYcZTCc=
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
@@ -55,6 +57,9 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -265,6 +270,11 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I=
gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4=
gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8=
gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=

View File

@@ -1,19 +1,19 @@
package alist
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
)
type Alist struct {
@@ -21,154 +21,72 @@ type Alist struct {
token string
baseURL string
loginInfo *loginRequest
config config.AlistConfig
}
var (
ErrAlistLoginFailed = errors.New("failed to login to Alist")
)
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Token string `json:"token"`
} `json:"data"`
}
type meResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
ID int `json:"id"`
Username string `json:"username"`
} `json:"data"`
}
type putResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Task struct {
ID string `json:"id"`
Name string `json:"name"`
State int `json:"state"`
Status string `json:"status"`
Progress int `json:"progress"`
Error string `json:"error"`
} `json:"task"`
} `json:"data"`
}
func (a *Alist) getToken() error {
loginBody, err := json.Marshal(a.loginInfo)
if err != nil {
return fmt.Errorf("failed to marshal login request: %w", err)
func (a *Alist) Init(model types.StorageModel) error {
var alistConfig config.AlistConfig
if err := json.Unmarshal([]byte(model.Config), &alistConfig); err != nil {
return fmt.Errorf("failed to unmarshal alist config: %w", err)
}
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
if err != nil {
return fmt.Errorf("failed to create login request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := a.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send login request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read login response: %w", err)
}
var loginResp loginResponse
if err := json.Unmarshal(body, &loginResp); err != nil {
return fmt.Errorf("failed to unmarshal login response: %w", err)
}
if loginResp.Code != http.StatusOK {
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
}
a.token = loginResp.Data.Token
return nil
}
func (a *Alist) refreshToken() {
for {
time.Sleep(time.Duration(config.Cfg.Storage.Alist.TokenExp) * time.Second)
if err := a.getToken(); err != nil {
logger.L.Errorf("Failed to refresh jwt token: %v", err)
continue
}
logger.L.Info("Refreshed Alist jwt token")
}
}
func (a *Alist) Init() {
a.baseURL = config.Cfg.Storage.Alist.URL
a.client = &http.Client{
Timeout: 12 * time.Hour,
Transport: &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
},
}
if config.Cfg.Storage.Alist.Token != "" {
a.token = config.Cfg.Storage.Alist.Token
a.config = alistConfig
a.baseURL = alistConfig.URL
a.client = getHttpClient()
if alistConfig.Token != "" {
a.token = alistConfig.Token
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.baseURL+"/api/me", nil)
if err != nil {
logger.L.Fatalf("Failed to create request: %v", err)
os.Exit(1)
return err
}
req.Header.Set("Authorization", a.token)
resp, err := a.client.Do(req)
if err != nil {
logger.L.Fatalf("Failed to send request: %v", err)
os.Exit(1)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.L.Fatalf("Failed to get alist user info: %s", resp.Status)
os.Exit(1)
return err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.L.Fatalf("Failed to read response body: %v", err)
os.Exit(1)
return err
}
var meResp meResponse
if err := json.Unmarshal(body, &meResp); err != nil {
logger.L.Fatalf("Failed to unmarshal me response: %v", err)
os.Exit(1)
return err
}
if meResp.Code != http.StatusOK {
logger.L.Fatalf("Failed to get alist user info: %s", meResp.Message)
os.Exit(1)
return err
}
logger.L.Debugf("Logged in Alist as %s", meResp.Data.Username)
return
return nil
}
a.loginInfo = &loginRequest{
Username: config.Cfg.Storage.Alist.Username,
Password: config.Cfg.Storage.Alist.Password,
Username: alistConfig.Username,
Password: alistConfig.Password,
}
if err := a.getToken(); err != nil {
logger.L.Fatalf("Failed to login to Alist: %v", err)
os.Exit(1)
return err
}
logger.L.Debug("Logged in to Alist")
go a.refreshToken()
go a.refreshToken(alistConfig)
return nil
}
func (a *Alist) Type() types.StorageType {
return types.StorageTypeAlist
}
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
@@ -219,3 +137,7 @@ func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
return nil
}
func (a *Alist) JoinStoragePath(task types.Task) string {
return path.Join(a.config.BasePath, task.StoragePath)
}

60
storage/alist/token.go Normal file
View File

@@ -0,0 +1,60 @@
package alist
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
)
func (a *Alist) getToken() error {
loginBody, err := json.Marshal(a.loginInfo)
if err != nil {
return fmt.Errorf("failed to marshal login request: %w", err)
}
req, err := http.NewRequest(http.MethodPost, a.baseURL+"/api/auth/login", bytes.NewBuffer(loginBody))
if err != nil {
return fmt.Errorf("failed to create login request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := a.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send login request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read login response: %w", err)
}
var loginResp loginResponse
if err := json.Unmarshal(body, &loginResp); err != nil {
return fmt.Errorf("failed to unmarshal login response: %w", err)
}
if loginResp.Code != http.StatusOK {
return fmt.Errorf("%w: %s", ErrAlistLoginFailed, loginResp.Message)
}
a.token = loginResp.Data.Token
return nil
}
func (a *Alist) refreshToken(cfg config.AlistConfig) {
for {
time.Sleep(time.Duration(cfg.TokenExp) * time.Second)
if err := a.getToken(); err != nil {
logger.L.Errorf("Failed to refresh jwt token: %v", err)
continue
}
logger.L.Info("Refreshed Alist jwt token")
}
}

44
storage/alist/types.go Normal file
View File

@@ -0,0 +1,44 @@
package alist
import "errors"
var (
ErrAlistLoginFailed = errors.New("failed to login to Alist")
)
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Token string `json:"token"`
} `json:"data"`
}
type meResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
ID int `json:"id"`
Username string `json:"username"`
} `json:"data"`
}
type putResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Task struct {
ID string `json:"id"`
Name string `json:"name"`
State int `json:"state"`
Status string `json:"status"`
Progress int `json:"progress"`
Error string `json:"error"`
} `json:"task"`
} `json:"data"`
}

23
storage/alist/utils.go Normal file
View File

@@ -0,0 +1,23 @@
package alist
import (
"net/http"
"time"
)
var (
httpClient *http.Client
)
func getHttpClient() *http.Client {
if httpClient != nil {
return httpClient
}
httpClient = &http.Client{
Timeout: 12 * time.Hour,
Transport: &http.Transport{
TLSHandshakeTimeout: 10 * time.Second,
},
}
return httpClient
}

View File

@@ -2,22 +2,35 @@ package local
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
)
type Local struct{}
type Local struct {
config config.LocalConfig
}
func (l *Local) Init() {
err := os.MkdirAll(config.Cfg.Storage.Local.BasePath, os.ModePerm)
if err != nil {
logger.L.Fatalf("Failed to create local storage directory: %s", err)
os.Exit(1)
func (l *Local) Init(model types.StorageModel) error {
var localConfig config.LocalConfig
if err := json.Unmarshal([]byte(model.Config), &localConfig); err != nil {
return fmt.Errorf("failed to unmarshal local config: %w", err)
}
l.config = localConfig
err := os.MkdirAll(localConfig.BasePath, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create local storage directory: %w", err)
}
return nil
}
func (l *Local) Type() types.StorageType {
return types.StorageTypeLocal
}
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
@@ -30,3 +43,7 @@ func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
}
return fileutil.CopyFile(filePath, storagePath)
}
func (l *Local) JoinStoragePath(task types.Task) string {
return filepath.Join(l.config.BasePath, task.StoragePath)
}

View File

@@ -3,13 +3,8 @@ package storage
import (
"context"
"errors"
"path"
"path/filepath"
"sync"
"github.com/duke-git/lancet/v2/slice"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/webdav"
@@ -17,68 +12,72 @@ import (
)
type Storage interface {
Init()
Init(model types.StorageModel) error
Type() types.StorageType
JoinStoragePath(task types.Task) string
Save(cttx context.Context, localFilePath, storagePath string) error
}
var Storages = make(map[types.StorageType]Storage)
var StorageKeys = make([]types.StorageType, 0)
var (
ErrInvalidStorageID = errors.New("invalid storage ID")
)
func Init() {
logger.L.Debug("Initializing storage...")
if config.Cfg.Storage.Alist.Enable {
Storages[types.Alist] = new(alist.Alist)
Storages[types.Alist].Init()
}
if config.Cfg.Storage.Local.Enable {
Storages[types.Local] = new(local.Local)
Storages[types.Local].Init()
}
if config.Cfg.Storage.Webdav.Enable {
Storages[types.Webdav] = new(webdav.Webdav)
Storages[types.Webdav].Init()
}
var Storages = make(map[uint]Storage)
for k := range Storages {
StorageKeys = append(StorageKeys, k)
// LoadExistingStorages loads existing storages from the database, and initializes them
//
// Should only be called at startup
func LoadExistingStorages() error {
storageModels, err := dao.GetActiveStorages()
if err != nil {
return err
}
slice.Sort(StorageKeys)
logger.L.Debug("Storage initialized")
}
func Save(storageType types.StorageType, ctx context.Context, filePath, storagePath string) error {
logger.L.Debugf("Saving file %s to storage: [%s] %s", filePath, storageType, storagePath)
if ctx == nil {
ctx = context.Background()
}
if storageType != types.StorageAll {
return Storages[storageType].Save(ctx, filePath, storagePath)
}
errs := make([]error, 0)
var wg sync.WaitGroup
for _, storage := range Storages {
wg.Add(1)
go func(storage Storage) {
defer wg.Done()
storageDestPath := storagePath
switch storage.(type) {
case *local.Local:
storageDestPath = filepath.Join(config.Cfg.Storage.Local.BasePath, storagePath)
case *webdav.Webdav:
storageDestPath = path.Join(config.Cfg.Storage.Webdav.BasePath, storagePath)
case *alist.Alist:
storageDestPath = path.Join(config.Cfg.Storage.Alist.BasePath, storagePath)
}
if err := storage.Save(ctx, filePath, storageDestPath); err != nil {
errs = append(errs, err)
}
}(storage)
}
wg.Wait()
if len(errs) > 0 {
return errors.Join(errs...)
for _, storageModel := range storageModels {
storage, err := NewStorage(storageModel)
if err != nil {
return err
}
Storages[storageModel.ID] = storage
}
return nil
}
// Get storage from model, if it exists, otherwise create and init a new storage
func GetStorageFromModel(model types.StorageModel) (Storage, error) {
if model.ID == 0 {
return nil, ErrInvalidStorageID
}
if storage, ok := Storages[model.ID]; ok {
return storage, nil
}
storage, err := NewStorage(model)
if err != nil {
return nil, err
}
Storages[model.ID] = storage
return storage, nil
}
func NewStorage(storageModel types.StorageModel) (Storage, error) {
switch storageModel.Type {
case string(types.StorageTypeAlist):
alistStorage := new(alist.Alist)
if err := alistStorage.Init(storageModel); err != nil {
return nil, err
}
return alistStorage, nil
case string(types.StorageTypeLocal):
localStorage := new(local.Local)
if err := localStorage.Init(storageModel); err != nil {
return nil, err
}
return localStorage, nil
case string(types.StorageTypeWebdav):
webdavStorage := new(webdav.Webdav)
if err := webdavStorage.Init(storageModel); err != nil {
return nil, err
}
return webdavStorage, nil
}
return nil, nil
}

View File

@@ -2,29 +2,42 @@ package webdav
import (
"context"
"encoding/json"
"fmt"
"os"
"path"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
"github.com/studio-b12/gowebdav"
)
type Webdav struct{}
type Webdav struct {
config config.WebdavConfig
}
var (
Client *gowebdav.Client
)
func (w *Webdav) Init() {
webdavConfig := config.Cfg.Storage.Webdav
func (w *Webdav) Init(model types.StorageModel) error {
var webdavConfig config.WebdavConfig
if err := json.Unmarshal([]byte(model.Config), &webdavConfig); err != nil {
return fmt.Errorf("failed to unmarshal webdav config: %w", err)
}
w.config = webdavConfig
Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := Client.Connect(); err != nil {
logger.L.Fatalf("Failed to connect to webdav server: %v", err)
os.Exit(1)
return fmt.Errorf("failed to connect to webdav server: %w", err)
}
Client.SetTimeout(24 * time.Hour)
Client.SetTimeout(12 * time.Hour)
return nil
}
func (w *Webdav) Type() types.StorageType {
return types.StorageTypeWebdav
}
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
@@ -45,3 +58,7 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
}
return nil
}
func (w *Webdav) JoinStoragePath(task types.Task) string {
return path.Join(w.config.BasePath, task.StoragePath)
}

View File

@@ -1,14 +1,17 @@
package types
import (
"gorm.io/datatypes"
"gorm.io/gorm"
)
type ReceivedFile struct {
gorm.Model
Processing bool
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
Processing bool
// Which chat the file is from
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
// Which message the file is from
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
ReplyMessageID int
ReplyChatID int64
FileName string
@@ -16,7 +19,18 @@ type ReceivedFile struct {
type User struct {
gorm.Model
UserID int64 `gorm:"uniqueIndex"`
Silent bool
DefaultStorage string
ChatID int64 `gorm:"uniqueIndex"` // Telegram user ID
Silent bool
DefaultStorageID uint
Storages []*StorageModel `gorm:"many2many:user_storages;"`
}
type StorageModel struct {
gorm.Model
Type string
Name string // just for display
Desc string
Active bool
Config datatypes.JSON
Users []*User `gorm:"many2many:user_storages;"`
}

View File

@@ -22,20 +22,20 @@ var (
type StorageType string
var (
StorageAll StorageType = "all"
Local StorageType = "local"
Webdav StorageType = "webdav"
Alist StorageType = "alist"
StorageAll StorageType = "all"
StorageTypeLocal StorageType = "local"
StorageTypeWebdav StorageType = "webdav"
StorageTypeAlist StorageType = "alist"
)
var StorageTypes = []StorageType{Local, Alist, Webdav, StorageAll}
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageAll}
type Task struct {
Ctx context.Context
Error error
Status TaskStatus
File *File
Storage StorageType
StorageID uint
StoragePath string
StartTime time.Time