feat!: (WIP) switched back to using config files config storages because the conversation handling is shit
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
@@ -9,7 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func InitAll() {
|
||||
config.Init()
|
||||
if err := config.Init(); err != nil {
|
||||
fmt.Println("Failed to init config: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.InitLogger()
|
||||
logger.L.Info("Starting SaveAny-Bot...")
|
||||
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func manageStorageEntry(ctx *ext.Context, update *ext.Update) error {
|
||||
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user active storages: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
state, ok := userConversationState[user.ChatID]
|
||||
if !ok {
|
||||
state = &ConversationState{}
|
||||
userConversationState[user.ChatID] = state
|
||||
}
|
||||
state.Reset()
|
||||
state.InConversation = true
|
||||
state.SetConversationType(ConversationTypeManageStorage)
|
||||
state.SetData("status", "entry")
|
||||
|
||||
storagesMsg := "已添加的存储:"
|
||||
if len(user.Storages) == 0 {
|
||||
storagesMsg += " 无"
|
||||
} else {
|
||||
for i, storage := range user.Storages {
|
||||
storagesMsg += fmt.Sprintf("\n%d. %s", i+1, storage.Name)
|
||||
}
|
||||
}
|
||||
storagesMsg += "\n\n请选择操作:"
|
||||
_, err = ctx.Reply(update, ext.ReplyTextString(storagesMsg), &ext.ReplyOpts{
|
||||
Markup: &manageStorageKeyboardMarkup,
|
||||
})
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to send manage storage message: %s", err)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func handleManageStorageConversation(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
status := state.GetData("status").(string)
|
||||
switch status {
|
||||
case "entry":
|
||||
return manageStorageMenu(ctx, update, state)
|
||||
case "add_select_type":
|
||||
return manageStorageAddSelectType(ctx, update, state)
|
||||
case "selected_add_type":
|
||||
return manageStorageAddSelectedType(ctx, update, state)
|
||||
default:
|
||||
logger.L.Errorf("Unknown manage storage status: %s", status)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func manageStorageMenu(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
text := update.EffectiveMessage.Text
|
||||
switch text {
|
||||
case manageStorageButtonAdd:
|
||||
return manageStorageAdd(ctx, update, state)
|
||||
case manageStorageButtonDelete:
|
||||
return manageStorageDelete(ctx, update)
|
||||
case manageStorageButtonEdit:
|
||||
return manageStorageEdit(ctx, update)
|
||||
case manageStorageButtonSetDefault:
|
||||
return manageStorageSetDefault(ctx, update)
|
||||
default:
|
||||
logger.L.Errorf("Unknown manage storage button: %s", text)
|
||||
ctx.Reply(update, ext.ReplyTextString("未知操作"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
}
|
||||
|
||||
func manageStorageAdd(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
rows := make([]tg.KeyboardButtonRow, 0)
|
||||
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||
for i, storageType := range types.StorageTypes {
|
||||
buttons = append(buttons, &tg.KeyboardButton{
|
||||
Text: types.StorageTypeDisplay[storageType],
|
||||
})
|
||||
if (i+1)%3 == 0 || i == len(types.StorageTypes)-1 {
|
||||
rows = append(rows, tg.KeyboardButtonRow{
|
||||
Buttons: buttons,
|
||||
})
|
||||
buttons = make([]tg.KeyboardButtonClass, 0)
|
||||
}
|
||||
}
|
||||
manageStorageAddKeyboardMarkup := tg.ReplyKeyboardMarkup{
|
||||
Selective: true,
|
||||
Resize: true,
|
||||
Rows: rows,
|
||||
}
|
||||
|
||||
state.SetData("status", "add_select_type")
|
||||
|
||||
ctx.Reply(update, ext.ReplyTextString("请选择要添加的存储类型"), &ext.ReplyOpts{
|
||||
Markup: &manageStorageAddKeyboardMarkup,
|
||||
})
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
|
||||
func manageStorageAddSelectType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
text := update.EffectiveMessage.Text
|
||||
var storageType types.StorageType
|
||||
for t, display := range types.StorageTypeDisplay {
|
||||
if display == text {
|
||||
storageType = t
|
||||
break
|
||||
}
|
||||
}
|
||||
if storageType == "" {
|
||||
ctx.Reply(update, ext.ReplyTextString("未知的存储类型"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
state.SetData("status", "selected_add_type")
|
||||
state.SetData("storage_type", storageType)
|
||||
return manageStorageAddSelectedType(ctx, update, state)
|
||||
}
|
||||
|
||||
func manageStorageAddSelectedType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
selectedType := state.GetData("storage_type").(types.StorageType)
|
||||
configItems := storage.GetStorageConfigurableItems(selectedType)
|
||||
configIndexData := state.GetData("configindex")
|
||||
configIndex := 0
|
||||
if configIndexData == nil {
|
||||
state.SetData("configindex", configIndex)
|
||||
} else {
|
||||
configIndex = configIndexData.(int)
|
||||
if update.EffectiveMessage.Text != "" {
|
||||
logger.L.Debugf("config %s: %s", configItems[configIndex-1], update.EffectiveMessage.Text)
|
||||
state.SetData(configItems[configIndex-1], update.EffectiveMessage.Text)
|
||||
}
|
||||
}
|
||||
if configIndex >= len(configItems) {
|
||||
// TODO: save storage
|
||||
state.SetData("status", "add_complete")
|
||||
logger.L.Infof("Save storage")
|
||||
return manageStorageSave(ctx, update, state)
|
||||
}
|
||||
|
||||
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("正在配置 %s 存储...\n请提供 %s", types.StorageTypeDisplay[selectedType], configItems[configIndex])), &ext.ReplyOpts{
|
||||
Markup: &tg.ReplyKeyboardForceReply{
|
||||
Selective: true,
|
||||
SingleUse: true,
|
||||
},
|
||||
})
|
||||
state.SetData("configindex", configIndex+1)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func manageStorageSave(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.L.Errorf("Failed to save storage: %s", r)
|
||||
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
|
||||
}
|
||||
state.Reset()
|
||||
}()
|
||||
storageType := state.GetData("storage_type").(types.StorageType)
|
||||
config := make(map[string]string)
|
||||
configItems := storage.GetStorageConfigurableItems(storageType)
|
||||
for _, item := range configItems {
|
||||
config[item] = state.GetData(item).(string)
|
||||
}
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to marshal storage config: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get user: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storageModel := types.StorageModel{
|
||||
Type: string(storageType),
|
||||
Active: true,
|
||||
Config: configJSON,
|
||||
}
|
||||
hash := storageModel.GenHash()
|
||||
storageModel.Hash = hash
|
||||
if storagedb, err := dao.GetStorageByHash(hash); err == nil {
|
||||
logger.L.Debugf("Storage already exists")
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
} else {
|
||||
if id, err := dao.CreateStorage(&storageModel); err != nil {
|
||||
logger.L.Errorf("Failed to create storage: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("存储创建失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
} else {
|
||||
storagedb := &types.StorageModel{}
|
||||
storagedb.ID = id
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
}
|
||||
}
|
||||
if err := dao.UpdateUser(user); err != nil {
|
||||
logger.L.Errorf("Failed to update user with storages: %s", err)
|
||||
ctx.Reply(update, ext.ReplyTextString("用户更新失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.Reply(update, ext.ReplyTextString("存储已添加"), &ext.ReplyOpts{
|
||||
Markup: &tg.ReplyKeyboardHide{},
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func manageStorageDelete(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
|
||||
func manageStorageEdit(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
|
||||
func manageStorageSetDefault(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,9 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if len(user.Storages) == 0 {
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -91,14 +94,14 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent || user.DefaultStorageID == 0 {
|
||||
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageID: user.DefaultStorageID,
|
||||
StorageName: user.DefaultStorage,
|
||||
FileChatID: linkChat.GetID(),
|
||||
FileMessageID: messageID,
|
||||
ReplyMessageID: replied.ID,
|
||||
|
||||
@@ -2,18 +2,10 @@ package bot
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
)
|
||||
|
||||
type ConversationType string
|
||||
|
||||
const (
|
||||
ConversationTypeManageStorage ConversationType = "manage_storage"
|
||||
)
|
||||
|
||||
type ConversationState struct {
|
||||
sync.Mutex
|
||||
conversationType ConversationType
|
||||
@@ -54,31 +46,30 @@ func (c *ConversationState) SetData(key string, value interface{}) {
|
||||
c.data[c.conversationType][key] = value
|
||||
}
|
||||
|
||||
var userConversationState = make(map[int64]*ConversationState)
|
||||
// TODO: Implement conversation handling
|
||||
// var userConversationState = make(map[int64]*ConversationState)
|
||||
|
||||
func handleConversation(ctx *ext.Context, update *ext.Update) error {
|
||||
userID := update.EffectiveUser().GetID()
|
||||
state, ok := userConversationState[userID]
|
||||
if !ok {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
if update.EffectiveMessage.Text == "/cancel" {
|
||||
state.Reset()
|
||||
ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !state.InConversation {
|
||||
return dispatcher.ContinueGroups
|
||||
}
|
||||
return handleConversationState(ctx, update, state)
|
||||
}
|
||||
// func handleConversation(ctx *ext.Context, update *ext.Update) error {
|
||||
// userID := update.EffectiveUser().GetID()
|
||||
// state, ok := userConversationState[userID]
|
||||
// if !ok {
|
||||
// return dispatcher.ContinueGroups
|
||||
// }
|
||||
// if update.EffectiveMessage.Text == "/cancel" {
|
||||
// state.Reset()
|
||||
// ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
|
||||
// return dispatcher.EndGroups
|
||||
// }
|
||||
// if !state.InConversation {
|
||||
// return dispatcher.ContinueGroups
|
||||
// }
|
||||
// return handleConversationState(ctx, update, state)
|
||||
// }
|
||||
|
||||
func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
switch state.conversationType {
|
||||
case ConversationTypeManageStorage:
|
||||
return handleManageStorageConversation(ctx, update, state)
|
||||
default:
|
||||
logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
// func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
|
||||
// switch state.conversationType {
|
||||
// default:
|
||||
// logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
|
||||
// }
|
||||
// return dispatcher.EndGroups
|
||||
// }
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
dispatcher.AddHandler(handlers.NewCommand("start", start))
|
||||
dispatcher.AddHandler(handlers.NewCommand("help", help))
|
||||
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
|
||||
dispatcher.AddHandler(handlers.NewCommand("storage", manageStorageEntry))
|
||||
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
|
||||
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
|
||||
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
|
||||
if err != nil {
|
||||
@@ -35,7 +36,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
|
||||
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
|
||||
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
|
||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
|
||||
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
|
||||
// dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
|
||||
}
|
||||
|
||||
const noPermissionText string = `
|
||||
@@ -120,7 +121,10 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if len(user.Storages) == 0 {
|
||||
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -180,14 +184,14 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
}
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if !user.Silent || user.DefaultStorageID == 0 {
|
||||
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageID: user.DefaultStorageID,
|
||||
StorageName: user.DefaultStorage,
|
||||
FileChatID: update.EffectiveChat().GetID(),
|
||||
ReplyMessageID: replied.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
@@ -195,6 +199,11 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
})
|
||||
}
|
||||
|
||||
func storageCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
// TODO: Implement
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
|
||||
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)
|
||||
@@ -211,7 +220,8 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
if len(user.Storages) == 0 {
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
if len(storages) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -250,14 +260,14 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
if !user.Silent || user.DefaultStorageID == 0 {
|
||||
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
|
||||
if !user.Silent || user.DefaultStorage == "" {
|
||||
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
|
||||
}
|
||||
return HandleSilentAddTask(ctx, update, user, &types.Task{
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageID: user.DefaultStorageID,
|
||||
StorageName: user.DefaultStorage,
|
||||
FileChatID: update.EffectiveChat().GetID(),
|
||||
ReplyMessageID: msg.ID,
|
||||
ReplyChatID: update.GetUserChat().GetID(),
|
||||
@@ -278,8 +288,19 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
args := strings.Split(string(update.CallbackQuery.Data), " ")
|
||||
fileChatID, _ := strconv.Atoi(args[1])
|
||||
fileMessageID, _ := strconv.Atoi(args[2])
|
||||
storageID, _ := strconv.Atoi(args[3])
|
||||
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storageID: %d", fileChatID, fileMessageID, storageID)
|
||||
storageNameHash := args[3]
|
||||
storageName := storageHashName[storageNameHash]
|
||||
if storageName == "" {
|
||||
logger.L.Errorf("Unknown storage name hash: %d", storageNameHash)
|
||||
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
|
||||
QueryID: update.CallbackQuery.QueryID,
|
||||
Alert: true,
|
||||
Message: "未知存储位置",
|
||||
CacheTime: 5,
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", fileChatID, fileMessageID, storageName)
|
||||
record, err := dao.GetReceivedFileByChatAndMessageID(int64(fileChatID), fileMessageID)
|
||||
if err != nil {
|
||||
logger.L.Errorf("Failed to get received file: %s", err)
|
||||
@@ -313,7 +334,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
|
||||
Ctx: ctx,
|
||||
Status: types.Pending,
|
||||
File: file,
|
||||
StorageID: uint(storageID),
|
||||
StorageName: storageName,
|
||||
FileChatID: record.ChatID,
|
||||
ReplyMessageID: record.ReplyMessageID,
|
||||
FileMessageID: record.MessageID,
|
||||
|
||||
56
bot/utils.go
56
bot/utils.go
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
@@ -25,39 +26,6 @@ var (
|
||||
ErrNoStorages = errors.New("no available storage")
|
||||
)
|
||||
|
||||
var (
|
||||
manageStorageButtonAdd = "添加存储"
|
||||
manageStorageButtonDelete = "删除存储"
|
||||
manageStorageButtonEdit = "修改存储"
|
||||
manageStorageButtonSetDefault = "设置默认存储"
|
||||
manageStorageKeyboardMarkup = tg.ReplyKeyboardMarkup{
|
||||
Selective: true,
|
||||
Resize: true,
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
&tg.KeyboardButton{
|
||||
Text: manageStorageButtonAdd,
|
||||
},
|
||||
&tg.KeyboardButton{
|
||||
Text: manageStorageButtonDelete,
|
||||
},
|
||||
&tg.KeyboardButton{
|
||||
Text: manageStorageButtonEdit,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
&tg.KeyboardButton{
|
||||
Text: manageStorageButtonSetDefault,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func supportedMediaFilter(m *tg.Message) (bool, error) {
|
||||
if not := m.Media == nil; not {
|
||||
return false, dispatcher.EndGroups
|
||||
@@ -72,19 +40,23 @@ func supportedMediaFilter(m *tg.Message) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// for callback data
|
||||
var storageHashName = map[string]string{}
|
||||
|
||||
func getSelectStorageMarkup(userChatID int64, fileChatID, fileMessageID int) (*tg.ReplyInlineMarkup, error) {
|
||||
user, err := dao.GetUserByChatID(userChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(user.Storages) < 1 {
|
||||
return nil, ErrNoStorages
|
||||
}
|
||||
storages := storage.GetUserStorages(user.ChatID)
|
||||
|
||||
buttons := make([]tg.KeyboardButtonClass, 0)
|
||||
for _, storage := range user.Storages {
|
||||
for _, storage := range storages {
|
||||
nameHash := common.HashString(storage.Name())
|
||||
storageHashName[nameHash] = storage.Name()
|
||||
buttons = append(buttons, &tg.KeyboardButtonCallback{
|
||||
Text: storage.Name,
|
||||
Data: []byte(fmt.Sprintf("add %d %d %d", fileChatID, fileMessageID, storage.ID)),
|
||||
Text: storage.Name(),
|
||||
Data: []byte(fmt.Sprintf("add %d %d %s", fileChatID, fileMessageID, nameHash)),
|
||||
})
|
||||
}
|
||||
markup := &tg.ReplyInlineMarkup{}
|
||||
@@ -194,7 +166,7 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
|
||||
return tgMessage, nil
|
||||
}
|
||||
|
||||
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int, fileMsgID, toEditMsgID int) error {
|
||||
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error {
|
||||
entityBuilder := entity.Builder{}
|
||||
var entities []tg.MessageEntityClass
|
||||
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
|
||||
@@ -207,7 +179,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
||||
} else {
|
||||
text, entities = entityBuilder.Complete()
|
||||
}
|
||||
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), chatID, fileMsgID)
|
||||
markup, err := getSelectStorageMarkup(update.EffectiveUser().GetID(), int(chatID), fileMsgID)
|
||||
if errors.Is(err, ErrNoStorages) {
|
||||
logger.L.Errorf("Failed to get select storage markup: %s", err)
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
@@ -236,7 +208,7 @@ func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File
|
||||
}
|
||||
|
||||
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
|
||||
if user.DefaultStorageID == 0 {
|
||||
if user.DefaultStorage == "" {
|
||||
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
|
||||
Message: "请先使用 /storage 设置默认存储位置",
|
||||
ID: task.ReplyMessageID,
|
||||
|
||||
12
common/utils.go
Normal file
12
common/utils.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func HashString(s string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(s))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
95
config/deprecated.go
Normal file
95
config/deprecated.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// for compatibility
|
||||
type deprecatedStorageConfig struct {
|
||||
Alist alistConfig `toml:"alist" mapstructure:"alist"`
|
||||
Local localConfig `toml:"local" mapstructure:"local"`
|
||||
Webdav webdavConfig `toml:"webdav" mapstructure:"webdav"`
|
||||
}
|
||||
|
||||
type alistConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *alistConfig) ToJSON() datatypes.JSON {
|
||||
tokenExp := strconv.FormatInt(a.TokenExp, 10)
|
||||
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
|
||||
}
|
||||
|
||||
type localConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *localConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
|
||||
}
|
||||
|
||||
type webdavConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *webdavConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
|
||||
}
|
||||
|
||||
func transformDeprecatedStorageConfig() {
|
||||
if Cfg.DeprecatedStorage.Alist.Enable {
|
||||
alistStorage := &AlistStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Alist",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeAlist),
|
||||
},
|
||||
URL: Cfg.DeprecatedStorage.Alist.URL,
|
||||
Username: Cfg.DeprecatedStorage.Alist.Username,
|
||||
Password: Cfg.DeprecatedStorage.Alist.Password,
|
||||
Token: Cfg.DeprecatedStorage.Alist.Token,
|
||||
BasePath: Cfg.DeprecatedStorage.Alist.BasePath,
|
||||
TokenExp: Cfg.DeprecatedStorage.Alist.TokenExp,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, alistStorage)
|
||||
}
|
||||
if Cfg.DeprecatedStorage.Local.Enable {
|
||||
localStorage := &LocalStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Local",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeLocal),
|
||||
},
|
||||
BasePath: Cfg.DeprecatedStorage.Local.BasePath,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, localStorage)
|
||||
}
|
||||
if Cfg.DeprecatedStorage.Webdav.Enable {
|
||||
webdavStorage := &WebdavStorageConfig{
|
||||
NewStorageConfig: NewStorageConfig{
|
||||
Name: "Webdav",
|
||||
Enable: true,
|
||||
Type: string(types.StorageTypeWebdav),
|
||||
},
|
||||
URL: Cfg.DeprecatedStorage.Webdav.URL,
|
||||
Username: Cfg.DeprecatedStorage.Webdav.Username,
|
||||
Password: Cfg.DeprecatedStorage.Webdav.Password,
|
||||
BasePath: Cfg.DeprecatedStorage.Webdav.BasePath,
|
||||
}
|
||||
Cfg.Storages = append(Cfg.Storages, webdavStorage)
|
||||
}
|
||||
}
|
||||
104
config/storage_factory.go
Normal file
104
config/storage_factory.go
Normal file
@@ -0,0 +1,104 @@
|
||||
// storage_config.go
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type StorageConfig interface {
|
||||
Validate() error
|
||||
GetType() types.StorageType
|
||||
GetName() string
|
||||
}
|
||||
|
||||
// Base storage config
|
||||
type NewStorageConfig struct {
|
||||
Name string `toml:"name" mapstructure:"name" json:"name"`
|
||||
Type string `toml:"type" mapstructure:"type" json:"type"`
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
RawConfig map[string]interface{} `toml:"-" mapstructure:",remain"`
|
||||
}
|
||||
|
||||
type StorageConfigFactory func(cfg *NewStorageConfig) (StorageConfig, error)
|
||||
|
||||
var storageFactories = make(map[string]StorageConfigFactory)
|
||||
|
||||
func RegisterStorageFactory(storageType string, factory StorageConfigFactory) {
|
||||
storageFactories[storageType] = factory
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterStorageFactory(string(types.StorageTypeLocal), newLocalStorageConfig)
|
||||
RegisterStorageFactory(string(types.StorageTypeAlist), newAlistStorageConfig)
|
||||
RegisterStorageFactory(string(types.StorageTypeWebdav), newWebdavStorageConfig)
|
||||
}
|
||||
|
||||
func newLocalStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var localCfg LocalStorageConfig
|
||||
localCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &localCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode local storage config: %w", err)
|
||||
}
|
||||
|
||||
return &localCfg, nil
|
||||
}
|
||||
|
||||
func newAlistStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var alistCfg AlistStorageConfig
|
||||
alistCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &alistCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode alist storage config: %w", err)
|
||||
}
|
||||
|
||||
return &alistCfg, nil
|
||||
}
|
||||
|
||||
func newWebdavStorageConfig(cfg *NewStorageConfig) (StorageConfig, error) {
|
||||
var webdavCfg WebdavStorageConfig
|
||||
webdavCfg.NewStorageConfig = *cfg
|
||||
|
||||
if err := mapstructure.Decode(cfg.RawConfig, &webdavCfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode webdav storage config: %w", err)
|
||||
}
|
||||
|
||||
return &webdavCfg, nil
|
||||
}
|
||||
|
||||
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
|
||||
var baseConfigs []NewStorageConfig
|
||||
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal storage configs: %w", err)
|
||||
}
|
||||
|
||||
var configs []StorageConfig
|
||||
for _, baseCfg := range baseConfigs {
|
||||
if !baseCfg.Enable {
|
||||
continue
|
||||
}
|
||||
|
||||
factory, ok := storageFactories[baseCfg.Type]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", baseCfg.Type)
|
||||
}
|
||||
|
||||
cfg, err := factory(&baseCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid storage config for %s: %w", baseCfg.Name, err)
|
||||
}
|
||||
|
||||
configs = append(configs, cfg)
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
106
config/storages.go
Normal file
106
config/storages.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func (c *Config) GetStoragesByType(storageType types.StorageType) []StorageConfig {
|
||||
var storages []StorageConfig
|
||||
for _, storage := range c.Storages {
|
||||
if storage.GetType() == storageType {
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
func (c *Config) GetStorageByName(name string) StorageConfig {
|
||||
for _, storage := range c.Storages {
|
||||
if storage.GetName() == name {
|
||||
return storage
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LocalStorageConfig struct {
|
||||
NewStorageConfig
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) Validate() error {
|
||||
if l.BasePath == "" {
|
||||
return fmt.Errorf("path is required for local storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeLocal
|
||||
}
|
||||
|
||||
func (l *LocalStorageConfig) GetName() string {
|
||||
return l.Name
|
||||
}
|
||||
|
||||
type AlistStorageConfig struct {
|
||||
NewStorageConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) Validate() error {
|
||||
if a.URL == "" {
|
||||
return fmt.Errorf("url is required for alist storage")
|
||||
}
|
||||
if a.Token == "" && (a.Username == "" || a.Password == "") {
|
||||
return fmt.Errorf("username and password or token is required for alist storage")
|
||||
}
|
||||
if a.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for alist storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeAlist
|
||||
}
|
||||
|
||||
func (a *AlistStorageConfig) GetName() string {
|
||||
return a.Name
|
||||
}
|
||||
|
||||
type WebdavStorageConfig struct {
|
||||
NewStorageConfig
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) Validate() error {
|
||||
if w.URL == "" {
|
||||
return fmt.Errorf("url is required for webdav storage")
|
||||
}
|
||||
if w.Username == "" || w.Password == "" {
|
||||
return fmt.Errorf("username and password is required for webdav storage")
|
||||
}
|
||||
if w.BasePath == "" {
|
||||
return fmt.Errorf("base_path is required for webdav storage")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetType() types.StorageType {
|
||||
return types.StorageTypeWebdav
|
||||
}
|
||||
|
||||
func (w *WebdavStorageConfig) GetName() string {
|
||||
return w.Name
|
||||
}
|
||||
28
config/user.go
Normal file
28
config/user.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
)
|
||||
|
||||
type userConfig struct {
|
||||
ID int64 `toml:"id" mapstructure:"id" json:"id"` // telegram user id
|
||||
Storages []string `toml:"storages" mapstructure:"storages" json:"storages"` // storage names
|
||||
Blacklist bool `toml:"blacklist" mapstructure:"blacklist" json:"blacklist"` // 黑名单模式, storage names 中的存储将不会被使用, 默认为白名单模式
|
||||
}
|
||||
|
||||
func (c *Config) GetStorageNamesByUserID(userID int64) []string {
|
||||
for _, user := range c.Users {
|
||||
if user.ID == userID {
|
||||
if user.Blacklist {
|
||||
allStorages := make([]string, 0, len(c.Storages))
|
||||
for _, storage := range c.Storages {
|
||||
allStorages = append(allStorages, storage.GetName())
|
||||
}
|
||||
return slice.Compact(slice.Difference(allStorages, user.Storages))
|
||||
} else {
|
||||
return user.Storages
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
132
config/viper.go
132
config/viper.go
@@ -3,23 +3,24 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Workers int `toml:"workers" mapstructure:"workers"`
|
||||
Retry int `toml:"retry" mapstructure:"retry"`
|
||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||
Workers int `toml:"workers" mapstructure:"workers"`
|
||||
Retry int `toml:"retry" mapstructure:"retry"`
|
||||
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
|
||||
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
|
||||
|
||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||
Log logConfig `toml:"log" mapstructure:"log"`
|
||||
DB dbConfig `toml:"db" mapstructure:"db"`
|
||||
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
||||
Storage storageConfig `toml:"storage" mapstructure:"storage"`
|
||||
Temp tempConfig `toml:"temp" mapstructure:"temp"`
|
||||
Log logConfig `toml:"log" mapstructure:"log"`
|
||||
DB dbConfig `toml:"db" mapstructure:"db"`
|
||||
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
|
||||
Storages []StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
|
||||
// Deprecated
|
||||
DeprecatedStorage deprecatedStorageConfig `toml:"storage" mapstructure:"storage"`
|
||||
}
|
||||
|
||||
type tempConfig struct {
|
||||
@@ -38,12 +39,13 @@ type dbConfig struct {
|
||||
}
|
||||
|
||||
type telegramConfig struct {
|
||||
Token string `toml:"token" mapstructure:"token"`
|
||||
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
|
||||
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
|
||||
// 白名单用户
|
||||
Admins []int64 `toml:"admins" mapstructure:"admins"` // Whitelisted users
|
||||
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
|
||||
Token string `toml:"token" mapstructure:"token"`
|
||||
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
|
||||
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
|
||||
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
|
||||
|
||||
// Deprecated
|
||||
Admins []int64 `toml:"admins" mapstructure:"admins"`
|
||||
}
|
||||
|
||||
type proxyConfig struct {
|
||||
@@ -51,56 +53,9 @@ type proxyConfig struct {
|
||||
URL string `toml:"url" mapstructure:"url"`
|
||||
}
|
||||
|
||||
// pre-defined storages, for compatibility.
|
||||
/*
|
||||
在配置文件中定义的存储将会为telegram.admins中的每个用户创建一个存储模型
|
||||
*/
|
||||
// these config will be removed in the future.
|
||||
type storageConfig struct {
|
||||
Alist AlistConfig `toml:"alist" mapstructure:"alist"`
|
||||
Local LocalConfig `toml:"local" mapstructure:"local"`
|
||||
Webdav WebdavConfig `toml:"webdav" mapstructure:"webdav"`
|
||||
}
|
||||
|
||||
type AlistConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
TokenExp int64 `toml:"token_exp" mapstructure:"token_exp" json:"token_exp"`
|
||||
}
|
||||
|
||||
func (a *AlistConfig) ToJSON() datatypes.JSON {
|
||||
tokenExp := strconv.FormatInt(a.TokenExp, 10)
|
||||
return datatypes.JSON([]byte(`{"url":"` + a.URL + `","username":"` + a.Username + `","password":"` + a.Password + `","token":"` + a.Token + `","base_path":"` + a.BasePath + `","token_exp":` + tokenExp + `}`))
|
||||
}
|
||||
|
||||
type LocalConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (l *LocalConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"base_path":"` + l.BasePath + `"}`))
|
||||
}
|
||||
|
||||
type WebdavConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
URL string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Username string `toml:"username" mapstructure:"username" json:"username"`
|
||||
Password string `toml:"password" mapstructure:"password" json:"password"`
|
||||
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
func (w *WebdavConfig) ToJSON() datatypes.JSON {
|
||||
return datatypes.JSON([]byte(`{"url":"` + w.URL + `","username":"` + w.Username + `","password":"` + w.Password + `","base_path":"` + w.BasePath + `"}`))
|
||||
}
|
||||
|
||||
var Cfg *Config
|
||||
|
||||
func Init() {
|
||||
func Init() error {
|
||||
viper.SetConfigName("config")
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath("/etc/saveany/")
|
||||
@@ -125,7 +80,11 @@ func Init() {
|
||||
|
||||
viper.SetDefault("db.path", "data/saveany.db")
|
||||
|
||||
viper.SafeWriteConfigAs("config.toml")
|
||||
if err := viper.SafeWriteConfigAs("config.toml"); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok {
|
||||
return fmt.Errorf("error saving default config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
fmt.Println("Error reading config file, ", err)
|
||||
@@ -133,17 +92,52 @@ func Init() {
|
||||
}
|
||||
|
||||
Cfg = &Config{}
|
||||
|
||||
if err := viper.Unmarshal(Cfg); err != nil {
|
||||
fmt.Println("Error unmarshalling config file, ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if Cfg.Storage != (storageConfig{}) {
|
||||
fmt.Println("警告: 存储配置已经废弃, 未来版本将会移除.\n请直接使用 Bot 命令添加存储.")
|
||||
|
||||
if Cfg.Telegram.Admins != nil {
|
||||
fmt.Println("警告: 你正在使用旧版 Telegram 管理员配置, 该配置下的用户将可用所有存储.\ntelegram.admins 未来版本将会被废弃, 请参考新的配置文件模板, 使用 users 配置替代.")
|
||||
for _, admin := range Cfg.Telegram.Admins {
|
||||
Cfg.Users = append(Cfg.Users, userConfig{
|
||||
ID: admin,
|
||||
Storages: []string{},
|
||||
Blacklist: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
storagesConfig, err := LoadStorageConfigs(viper.GetViper())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading storage configs: %w", err)
|
||||
}
|
||||
Cfg.Storages = storagesConfig
|
||||
|
||||
if Cfg.DeprecatedStorage != (deprecatedStorageConfig{}) {
|
||||
fmt.Println("\n警告: 你正在使用旧版存储配置, 未来版本将会被废弃.\n请参考新的配置文件模板.")
|
||||
transformDeprecatedStorageConfig()
|
||||
}
|
||||
|
||||
storageNames := make(map[string]struct{})
|
||||
for _, storage := range Cfg.Storages {
|
||||
if _, ok := storageNames[storage.GetName()]; ok {
|
||||
return fmt.Errorf("重复的存储名: %s", storage.GetName())
|
||||
}
|
||||
storageNames[storage.GetName()] = struct{}{}
|
||||
}
|
||||
|
||||
fmt.Printf("已加载 %d 个存储:\n", len(Cfg.Storages))
|
||||
for _, storage := range Cfg.Storages {
|
||||
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
|
||||
}
|
||||
|
||||
if Cfg.Workers < 1 || Cfg.Retry < 1 {
|
||||
fmt.Println("Invalid workers or retry value")
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("workers 和 retry 必须大于 0, 当前值: workers=%d, retry=%d", Cfg.Workers, Cfg.Retry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Set(key string, value any) {
|
||||
|
||||
10
core/core.go
10
core/core.go
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/bot"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/dao"
|
||||
"github.com/krau/SaveAny-Bot/logger"
|
||||
"github.com/krau/SaveAny-Bot/queue"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
@@ -41,11 +40,8 @@ func processPendingTask(task *types.Task) error {
|
||||
if task.StoragePath == "" {
|
||||
task.StoragePath = task.File.FileName
|
||||
}
|
||||
storageModel, err := dao.GetStorageByID(task.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskStorage, err := storage.GetStorageFromModel(*storageModel)
|
||||
|
||||
taskStorage, err := storage.GetStorageByName(task.StorageName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -142,7 +138,7 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
case types.Succeeded:
|
||||
logger.L.Infof("Task succeeded: %s", task.String())
|
||||
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: fmt.Sprintf("文件保存成功\n [%d]: %s", task.StorageID, task.StoragePath),
|
||||
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.StorageName, task.StoragePath),
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
case types.Failed:
|
||||
|
||||
@@ -104,8 +104,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
||||
entityBuilder := entity.Builder{}
|
||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
|
||||
task.FileName(),
|
||||
// TODO: use storage name instead of ID
|
||||
fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath),
|
||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||
getSpeed(bytesRead, startTime),
|
||||
getProgressBar(progress, barTotalCount),
|
||||
progress,
|
||||
@@ -115,7 +114,7 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
||||
styling.Plain("正在处理下载任务\n文件名: "),
|
||||
styling.Code(task.FileName()),
|
||||
styling.Plain("\n保存路径: "),
|
||||
styling.Code(fmt.Sprintf("[%d]:%s", task.StorageID, task.StoragePath)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
||||
styling.Plain("\n当前进度:\n "),
|
||||
|
||||
108
dao/db.go
108
dao/db.go
@@ -36,117 +36,11 @@ func Init() {
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.L.Debug("Database connected")
|
||||
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}, &types.StorageModel{}); err != nil {
|
||||
if err := db.AutoMigrate(&types.ReceivedFile{}, &types.User{}); err != nil {
|
||||
logger.L.Fatal("迁移数据库失败, 如果您从旧版本升级, 建议手动删除数据库文件后重试: ", err)
|
||||
}
|
||||
|
||||
for _, admin := range config.Cfg.Telegram.Admins {
|
||||
CreateUser(int64(admin))
|
||||
}
|
||||
|
||||
logger.L.Infof("Migrating config storages to users")
|
||||
storageCfg := config.Cfg.Storage
|
||||
|
||||
allUsers, err := GetAllUsers()
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to get all users: %v", err)
|
||||
} else {
|
||||
for _, user := range allUsers {
|
||||
found := false
|
||||
for _, admin := range config.Cfg.Telegram.Admins {
|
||||
if user.ChatID == int64(admin) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
logger.L.Debugf("Deleting user %d", user.ChatID)
|
||||
if err := DeleteUser(&user); err != nil {
|
||||
logger.L.Fatalf("Failed to delete user %d: %v", user.ChatID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: refactor this
|
||||
for _, admin := range config.Cfg.Telegram.Admins {
|
||||
user, err := GetUserByChatID(int64(admin))
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to get user by chat ID %d: %v", admin, err)
|
||||
continue
|
||||
}
|
||||
if len(user.Storages) > 0 {
|
||||
logger.L.Debugf("User %d already has storages", admin)
|
||||
continue
|
||||
}
|
||||
if storageCfg.Alist.Enable {
|
||||
alistStorage := &types.StorageModel{
|
||||
Type: string(types.StorageTypeAlist),
|
||||
Active: true,
|
||||
Config: storageCfg.Alist.ToJSON(),
|
||||
}
|
||||
hash := alistStorage.GenHash()
|
||||
alistStorage.Hash = hash
|
||||
if storagedb, err := GetStorageByHash(hash); err == nil {
|
||||
logger.L.Debugf("Alist storage already exists")
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
} else {
|
||||
id, err := CreateStorage(alistStorage)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to create storage: %v", err)
|
||||
} else {
|
||||
storagedb := &types.StorageModel{}
|
||||
storagedb.ID = id
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
}
|
||||
}
|
||||
}
|
||||
if storageCfg.Local.Enable {
|
||||
localStorage := &types.StorageModel{
|
||||
Type: string(types.StorageTypeLocal),
|
||||
Active: true,
|
||||
Config: storageCfg.Local.ToJSON(),
|
||||
}
|
||||
hash := localStorage.GenHash()
|
||||
localStorage.Hash = hash
|
||||
if storagedb, err := GetStorageByHash(hash); err == nil {
|
||||
logger.L.Debugf("Local storage already exists")
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
} else {
|
||||
id, err := CreateStorage(localStorage)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to create storage: %v", err)
|
||||
} else {
|
||||
storagedb := &types.StorageModel{}
|
||||
storagedb.ID = id
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
}
|
||||
}
|
||||
}
|
||||
if storageCfg.Webdav.Enable {
|
||||
webdavStorage := &types.StorageModel{
|
||||
Type: string(types.StorageTypeWebdav),
|
||||
Active: true,
|
||||
Config: storageCfg.Webdav.ToJSON(),
|
||||
}
|
||||
hash := webdavStorage.GenHash()
|
||||
webdavStorage.Hash = hash
|
||||
if storagedb, err := GetStorageByHash(hash); err == nil {
|
||||
logger.L.Debugf("Webdav storage already exists")
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
} else {
|
||||
id, err := CreateStorage(webdavStorage)
|
||||
if err != nil {
|
||||
logger.L.Fatalf("Failed to create storage: %v", err)
|
||||
} else {
|
||||
storagedb := &types.StorageModel{}
|
||||
storagedb.ID = id
|
||||
user.Storages = append(user.Storages, storagedb)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := UpdateUser(user); err != nil {
|
||||
logger.L.Fatalf("Failed to update user with storages: %v", err)
|
||||
}
|
||||
}
|
||||
logger.L.Infof("Migration done")
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
func GetActiveStorages() ([]types.StorageModel, error) {
|
||||
var storageModels []types.StorageModel
|
||||
err := db.Where("active = ?", true).Find(&storageModels).Error
|
||||
return storageModels, err
|
||||
}
|
||||
|
||||
func GetStorageByHash(hash string) (*types.StorageModel, error) {
|
||||
var storageModel types.StorageModel
|
||||
err := db.Where("hash = ?", hash).First(&storageModel).Error
|
||||
return &storageModel, err
|
||||
}
|
||||
|
||||
func GetStorageByID(id uint) (*types.StorageModel, error) {
|
||||
var storageModel types.StorageModel
|
||||
err := db.Preload("Users").First(&storageModel, id).Error
|
||||
return &storageModel, err
|
||||
}
|
||||
|
||||
func CreateStorage(model *types.StorageModel) (uint, error) {
|
||||
if model.Hash == "" {
|
||||
model.Hash = model.GenHash()
|
||||
}
|
||||
getModel, err := GetStorageByHash(model.Hash)
|
||||
if err == nil {
|
||||
return getModel.ID, nil
|
||||
}
|
||||
tx := db.Create(model)
|
||||
if tx.Error != nil {
|
||||
return 0, tx.Error
|
||||
}
|
||||
if model.Name == "" {
|
||||
model.Name = fmt.Sprintf("%s - %d", model.Type, model.ID)
|
||||
tx = db.Save(model)
|
||||
if tx.Error != nil {
|
||||
return 0, tx.Error
|
||||
}
|
||||
}
|
||||
return model.ID, nil
|
||||
}
|
||||
13
dao/user.go
13
dao/user.go
@@ -17,18 +17,9 @@ func GetAllUsers() ([]types.User, error) {
|
||||
return users, err
|
||||
}
|
||||
|
||||
// GetUserByUserID gets a user by their telegram user ID
|
||||
//
|
||||
// Return with active storages
|
||||
func GetUserByChatID(chatID int64) (*types.User, error) {
|
||||
var user types.User
|
||||
err := db.Preload("Storages", "active = ?", true).Where("chat_id = ?", chatID).First(&user).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserWithAllStoragesByChatID(chatID int64) (*types.User, error) {
|
||||
var user types.User
|
||||
err := db.Preload("Storages").Where("chat_id = ?", chatID).First(&user).Error
|
||||
err := db.Where("chat_id = ?", chatID).First(&user).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
@@ -37,5 +28,5 @@ func UpdateUser(user *types.User) error {
|
||||
}
|
||||
|
||||
func DeleteUser(user *types.User) error {
|
||||
return db.Select("Storages").Delete(user).Error
|
||||
return db.Delete(user).Error
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type Alist struct {
|
||||
token string
|
||||
baseURL string
|
||||
loginInfo *loginRequest
|
||||
config config.AlistConfig
|
||||
config config.AlistStorageConfig
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{
|
||||
@@ -33,12 +33,16 @@ var ConfigurableItems = []string{
|
||||
"token",
|
||||
}
|
||||
|
||||
func (a *Alist) Init(model types.StorageModel) error {
|
||||
var alistConfig config.AlistConfig
|
||||
if err := json.Unmarshal([]byte(model.Config), &alistConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal alist config: %w", err)
|
||||
func (a *Alist) Init(cfg config.StorageConfig) error {
|
||||
alistConfig, ok := cfg.(*config.AlistStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast alist config")
|
||||
}
|
||||
a.config = alistConfig
|
||||
if err := alistConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
a.config = *alistConfig
|
||||
|
||||
a.baseURL = alistConfig.URL
|
||||
a.client = getHttpClient()
|
||||
if alistConfig.Token != "" {
|
||||
@@ -90,7 +94,7 @@ func (a *Alist) Init(model types.StorageModel) error {
|
||||
}
|
||||
logger.L.Debug("Logged in to Alist")
|
||||
|
||||
go a.refreshToken(alistConfig)
|
||||
go a.refreshToken(*alistConfig)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,6 +102,10 @@ func (a *Alist) Type() types.StorageType {
|
||||
return types.StorageTypeAlist
|
||||
}
|
||||
|
||||
func (a *Alist) Name() string {
|
||||
return a.config.Name
|
||||
}
|
||||
|
||||
func (a *Alist) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -48,7 +48,7 @@ func (a *Alist) getToken() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Alist) refreshToken(cfg config.AlistConfig) {
|
||||
func (a *Alist) refreshToken(cfg config.AlistStorageConfig) {
|
||||
tokenExp := cfg.TokenExp
|
||||
if tokenExp <= 0 {
|
||||
logger.L.Warn("Invalid token expiration time, using default value")
|
||||
|
||||
@@ -2,7 +2,6 @@ package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,19 +12,22 @@ import (
|
||||
)
|
||||
|
||||
type Local struct {
|
||||
config config.LocalConfig
|
||||
config config.LocalStorageConfig
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{
|
||||
"base_path",
|
||||
}
|
||||
|
||||
func (l *Local) Init(model types.StorageModel) error {
|
||||
var localConfig config.LocalConfig
|
||||
if err := json.Unmarshal([]byte(model.Config), &localConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal local config: %w", err)
|
||||
func (l *Local) Init(cfg config.StorageConfig) error {
|
||||
localConfig, ok := cfg.(*config.LocalStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast local config")
|
||||
}
|
||||
l.config = localConfig
|
||||
if err := localConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
l.config = *localConfig
|
||||
err := os.MkdirAll(localConfig.BasePath, os.ModePerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local storage directory: %w", err)
|
||||
@@ -37,6 +39,10 @@ func (l *Local) Type() types.StorageType {
|
||||
return types.StorageTypeLocal
|
||||
}
|
||||
|
||||
func (l *Local) Name() string {
|
||||
return l.config.Name
|
||||
}
|
||||
|
||||
func (l *Local) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
absPath, err := filepath.Abs(storagePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,9 +2,9 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/storage/alist"
|
||||
"github.com/krau/SaveAny-Bot/storage/local"
|
||||
"github.com/krau/SaveAny-Bot/storage/webdav"
|
||||
@@ -12,34 +12,50 @@ import (
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Init(model types.StorageModel) error
|
||||
Init(cfg config.StorageConfig) error
|
||||
Type() types.StorageType
|
||||
Name() string
|
||||
JoinStoragePath(task types.Task) string
|
||||
Save(cttx context.Context, localFilePath, storagePath string) error
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidStorageID = errors.New("invalid storage ID")
|
||||
)
|
||||
var Storages = make(map[string]Storage)
|
||||
|
||||
var Storages = make(map[uint]Storage)
|
||||
|
||||
// Get storage from model, if it exists, otherwise create and init a new storage
|
||||
func GetStorageFromModel(model types.StorageModel) (Storage, error) {
|
||||
if model.ID == 0 {
|
||||
return nil, ErrInvalidStorageID
|
||||
// GetStorageByName returns storage by name from cache or creates new one
|
||||
func GetStorageByName(name string) (Storage, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("storage name is required")
|
||||
}
|
||||
if storage, ok := Storages[model.ID]; ok {
|
||||
|
||||
storage, ok := Storages[name]
|
||||
if ok {
|
||||
return storage, nil
|
||||
}
|
||||
storage, err := NewStorage(model)
|
||||
cfg := config.Cfg.GetStorageByName(name)
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("storage %s not found", name)
|
||||
}
|
||||
|
||||
storage, err := NewStorage(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Storages[model.ID] = storage
|
||||
Storages[name] = storage
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
func GetUserStorages(chatID int64) []Storage {
|
||||
var storages []Storage
|
||||
for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) {
|
||||
storage, err := GetStorageByName(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
type StorageConstructor func() Storage
|
||||
|
||||
var storageConstructors = map[string]StorageConstructor{
|
||||
@@ -48,15 +64,15 @@ var storageConstructors = map[string]StorageConstructor{
|
||||
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
|
||||
}
|
||||
|
||||
func NewStorage(model types.StorageModel) (Storage, error) {
|
||||
constructor, ok := storageConstructors[model.Type]
|
||||
func NewStorage(cfg config.StorageConfig) (Storage, error) {
|
||||
constructor, ok := storageConstructors[string(cfg.GetType())]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", model.Type)
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", cfg.GetType())
|
||||
}
|
||||
|
||||
storage := constructor()
|
||||
if err := storage.Init(model); err != nil {
|
||||
return nil, fmt.Errorf("failed to init %s storage: %w", model.Type, err)
|
||||
if err := storage.Init(cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to init %s storage: %w", cfg.GetName(), err)
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
|
||||
@@ -2,7 +2,6 @@ package webdav
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
@@ -15,18 +14,21 @@ import (
|
||||
)
|
||||
|
||||
type Webdav struct {
|
||||
config config.WebdavConfig
|
||||
config config.WebdavStorageConfig
|
||||
client *gowebdav.Client
|
||||
}
|
||||
|
||||
var ConfigurableItems = []string{"url", "username", "password", "base_path"}
|
||||
|
||||
func (w *Webdav) Init(model types.StorageModel) error {
|
||||
var webdavConfig config.WebdavConfig
|
||||
if err := json.Unmarshal([]byte(model.Config), &webdavConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal webdav config: %w", err)
|
||||
func (w *Webdav) Init(cfg config.StorageConfig) error {
|
||||
webdavConfig, ok := cfg.(*config.WebdavStorageConfig)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast webdav config")
|
||||
}
|
||||
w.config = webdavConfig
|
||||
if err := webdavConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
w.config = *webdavConfig
|
||||
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
|
||||
if err := client.Connect(); err != nil {
|
||||
return fmt.Errorf("failed to connect to webdav server: %w", err)
|
||||
@@ -40,6 +42,10 @@ func (w *Webdav) Type() types.StorageType {
|
||||
return types.StorageTypeWebdav
|
||||
}
|
||||
|
||||
func (w *Webdav) Name() string {
|
||||
return w.config.Name
|
||||
}
|
||||
|
||||
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
|
||||
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
|
||||
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -22,33 +18,7 @@ type ReceivedFile struct {
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
ChatID int64 `gorm:"uniqueIndex;not null"`
|
||||
Silent bool
|
||||
DefaultStorageID uint
|
||||
Storages []*StorageModel `gorm:"many2many:user_storages;"`
|
||||
}
|
||||
|
||||
type StorageModel struct {
|
||||
gorm.Model
|
||||
Type string
|
||||
Config datatypes.JSON
|
||||
Active bool
|
||||
Users []*User `gorm:"many2many:user_storages;"`
|
||||
Hash string `gorm:"uniqueIndex"`
|
||||
// just for display
|
||||
Name string `gorm:"not null"`
|
||||
Desc string
|
||||
}
|
||||
|
||||
func (s *StorageModel) GenHash() string {
|
||||
if s.Type == "" || s.Config == nil {
|
||||
return ""
|
||||
}
|
||||
typeBytes := []byte(s.Type)
|
||||
configBytes := s.Config
|
||||
structBytes := append(typeBytes, configBytes...)
|
||||
hash := md5.New()
|
||||
hash.Write(structBytes)
|
||||
hashBytes := hash.Sum(nil)
|
||||
return hex.EncodeToString(hashBytes)
|
||||
ChatID int64 `gorm:"uniqueIndex;not null"`
|
||||
Silent bool
|
||||
DefaultStorage string // Default storage name
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ type Task struct {
|
||||
Error error
|
||||
Status TaskStatus
|
||||
File *File
|
||||
StorageID uint
|
||||
StorageName string
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
|
||||
|
||||
Reference in New Issue
Block a user