feat: (WIP) add storage

Co-authored-by: AHCorn <42889600+AHCorn@users.noreply.github.com>
This commit is contained in:
krau
2025-02-18 22:53:07 +08:00
parent 18cd480264
commit 80696c9661
9 changed files with 411 additions and 54 deletions

232
bot/conversation_storage.go Normal file
View File

@@ -0,0 +1,232 @@
package bot
import (
"encoding/json"
"fmt"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func manageStorageEntry(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
if err != nil {
logger.L.Errorf("Failed to get user active storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil)
return dispatcher.EndGroups
}
state, ok := userConversationState[user.ChatID]
if !ok {
state = &ConversationState{}
userConversationState[user.ChatID] = state
}
state.Reset()
state.InConversation = true
state.SetConversationType(ConversationTypeManageStorage)
state.SetData("status", "entry")
storagesMsg := "已添加的存储:"
if len(user.Storages) == 0 {
storagesMsg += " 无"
} else {
for i, storage := range user.Storages {
storagesMsg += fmt.Sprintf("\n%d. %s", i+1, storage.Name)
}
}
storagesMsg += "\n\n请选择操作:"
_, err = ctx.Reply(update, ext.ReplyTextString(storagesMsg), &ext.ReplyOpts{
Markup: &manageStorageKeyboardMarkup,
})
if err != nil {
logger.L.Errorf("Failed to send manage storage message: %s", err)
return dispatcher.EndGroups
}
return dispatcher.EndGroups
}
func handleManageStorageConversation(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
status := state.GetData("status").(string)
switch status {
case "entry":
return manageStorageMenu(ctx, update, state)
case "add_select_type":
return manageStorageAddSelectType(ctx, update, state)
case "selected_add_type":
return manageStorageAddSelectedType(ctx, update, state)
default:
logger.L.Errorf("Unknown manage storage status: %s", status)
}
return dispatcher.EndGroups
}
func manageStorageMenu(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
text := update.EffectiveMessage.Text
switch text {
case manageStorageButtonAdd:
return manageStorageAdd(ctx, update, state)
case manageStorageButtonDelete:
return manageStorageDelete(ctx, update)
case manageStorageButtonEdit:
return manageStorageEdit(ctx, update)
case manageStorageButtonSetDefault:
return manageStorageSetDefault(ctx, update)
default:
logger.L.Errorf("Unknown manage storage button: %s", text)
ctx.Reply(update, ext.ReplyTextString("未知操作"), nil)
return dispatcher.EndGroups
}
}
func manageStorageAdd(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
rows := make([]tg.KeyboardButtonRow, 0)
buttons := make([]tg.KeyboardButtonClass, 0)
for i, storageType := range types.StorageTypes {
buttons = append(buttons, &tg.KeyboardButton{
Text: types.StorageTypeDisplay[storageType],
})
if (i+1)%3 == 0 || i == len(types.StorageTypes)-1 {
rows = append(rows, tg.KeyboardButtonRow{
Buttons: buttons,
})
buttons = make([]tg.KeyboardButtonClass, 0)
}
}
manageStorageAddKeyboardMarkup := tg.ReplyKeyboardMarkup{
Selective: true,
Resize: true,
Rows: rows,
}
state.SetData("status", "add_select_type")
ctx.Reply(update, ext.ReplyTextString("请选择要添加的存储类型"), &ext.ReplyOpts{
Markup: &manageStorageAddKeyboardMarkup,
})
return dispatcher.ContinueGroups
}
func manageStorageAddSelectType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
text := update.EffectiveMessage.Text
var storageType types.StorageType
for t, display := range types.StorageTypeDisplay {
if display == text {
storageType = t
break
}
}
if storageType == "" {
ctx.Reply(update, ext.ReplyTextString("未知的存储类型"), nil)
return dispatcher.EndGroups
}
state.SetData("status", "selected_add_type")
state.SetData("storage_type", storageType)
return manageStorageAddSelectedType(ctx, update, state)
}
func manageStorageAddSelectedType(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
selectedType := state.GetData("storage_type").(types.StorageType)
configItems := storage.GetStorageConfigurableItems(selectedType)
configIndexData := state.GetData("configindex")
configIndex := 0
if configIndexData == nil {
state.SetData("configindex", configIndex)
} else {
configIndex = configIndexData.(int)
if update.EffectiveMessage.Text != "" {
logger.L.Debugf("config %s: %s", configItems[configIndex-1], update.EffectiveMessage.Text)
state.SetData(configItems[configIndex-1], update.EffectiveMessage.Text)
}
}
if configIndex >= len(configItems) {
// TODO: save storage
state.SetData("status", "add_complete")
logger.L.Infof("Save storage")
return manageStorageSave(ctx, update, state)
}
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("正在配置 %s 存储...\n请提供 %s", types.StorageTypeDisplay[selectedType], configItems[configIndex])), &ext.ReplyOpts{
Markup: &tg.ReplyKeyboardForceReply{
Selective: true,
SingleUse: true,
},
})
state.SetData("configindex", configIndex+1)
return dispatcher.EndGroups
}
func manageStorageSave(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
defer func() {
if r := recover(); r != nil {
logger.L.Errorf("Failed to save storage: %s", r)
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
}
state.Reset()
}()
storageType := state.GetData("storage_type").(types.StorageType)
config := make(map[string]string)
configItems := storage.GetStorageConfigurableItems(storageType)
for _, item := range configItems {
config[item] = state.GetData(item).(string)
}
configJSON, err := json.Marshal(config)
if err != nil {
logger.L.Errorf("Failed to marshal storage config: %s", err)
ctx.Reply(update, ext.ReplyTextString("存储配置失败"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.EffectiveUser().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storageModel := types.StorageModel{
Type: string(storageType),
Active: true,
Config: configJSON,
}
hash := storageModel.GenHash()
storageModel.Hash = hash
if storagedb, err := dao.GetStorageByHash(hash); err == nil {
logger.L.Debugf("Storage already exists")
user.Storages = append(user.Storages, storagedb)
} else {
if id, err := dao.CreateStorage(&storageModel); err != nil {
logger.L.Errorf("Failed to create storage: %s", err)
ctx.Reply(update, ext.ReplyTextString("存储创建失败"), nil)
return dispatcher.EndGroups
} else {
storagedb := &types.StorageModel{}
storagedb.ID = id
user.Storages = append(user.Storages, storagedb)
}
}
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user with storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("用户更新失败"), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString("存储已添加"), &ext.ReplyOpts{
Markup: &tg.ReplyKeyboardHide{},
})
return dispatcher.EndGroups
}
func manageStorageDelete(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}
func manageStorageEdit(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}
func manageStorageSetDefault(ctx *ext.Context, update *ext.Update) error {
return dispatcher.ContinueGroups
}

View File

@@ -0,0 +1,84 @@
package bot
import (
"sync"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/logger"
)
type ConversationType string
const (
ConversationTypeManageStorage ConversationType = "manage_storage"
)
type ConversationState struct {
sync.Mutex
conversationType ConversationType
InConversation bool
data map[ConversationType]map[string]interface{}
}
func (c *ConversationState) Reset() {
c.Lock()
defer c.Unlock()
c.InConversation = false
c.conversationType = ""
c.data = make(map[ConversationType]map[string]interface{})
}
func (c *ConversationState) SetConversationType(t ConversationType) {
c.Lock()
defer c.Unlock()
c.conversationType = t
}
func (c *ConversationState) GetData(key string) interface{} {
if c.data == nil || c.data[c.conversationType] == nil {
return nil
}
return c.data[c.conversationType][key]
}
func (c *ConversationState) SetData(key string, value interface{}) {
c.Lock()
defer c.Unlock()
if c.data == nil {
c.data = make(map[ConversationType]map[string]interface{})
}
if c.data[c.conversationType] == nil {
c.data[c.conversationType] = make(map[string]interface{})
}
c.data[c.conversationType][key] = value
}
var userConversationState = make(map[int64]*ConversationState)
func handleConversation(ctx *ext.Context, update *ext.Update) error {
userID := update.EffectiveUser().GetID()
state, ok := userConversationState[userID]
if !ok {
return dispatcher.ContinueGroups
}
if update.EffectiveMessage.Text == "/cancel" {
state.Reset()
ctx.Reply(update, ext.ReplyTextString("已取消"), nil)
return dispatcher.EndGroups
}
if !state.InConversation {
return dispatcher.ContinueGroups
}
return handleConversationState(ctx, update, state)
}
func handleConversationState(ctx *ext.Context, update *ext.Update, state *ConversationState) error {
switch state.conversationType {
case ConversationTypeManageStorage:
return handleManageStorageConversation(ctx, update, state)
default:
logger.L.Errorf("Unknown conversation type: %s", state.conversationType)
}
return dispatcher.EndGroups
}

View File

@@ -26,9 +26,8 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewCommand("start", start))
dispatcher.AddHandler(handlers.NewCommand("help", help))
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
dispatcher.AddHandler(handlers.NewCommand("storage", manageStorageEntry))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCommand("path", setPath))
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
if err != nil {
logger.L.Panicf("Failed to create regex filter: %s", err)
@@ -36,6 +35,7 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Text, handleConversation))
}
const noPermissionText string = `
@@ -97,21 +97,6 @@ func silent(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user active storages: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户存储失败"), nil)
return dispatcher.EndGroups
}
if len(user.Storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
// TODO: select storage
return dispatcher.EndGroups
}
func saveCmd(ctx *ext.Context, update *ext.Update) error {
res, ok := update.EffectiveMessage.GetReplyTo()
if !ok || res == nil {
@@ -210,11 +195,6 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
})
}
func setPath(ctx *ext.Context, update *ext.Update) error {
// TODO: implement
return dispatcher.EndGroups
}
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage.Message)

View File

@@ -25,6 +25,39 @@ var (
ErrNoStorages = errors.New("no available storage")
)
var (
manageStorageButtonAdd = "添加存储"
manageStorageButtonDelete = "删除存储"
manageStorageButtonEdit = "修改存储"
manageStorageButtonSetDefault = "设置默认存储"
manageStorageKeyboardMarkup = tg.ReplyKeyboardMarkup{
Selective: true,
Resize: true,
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButton{
Text: manageStorageButtonAdd,
},
&tg.KeyboardButton{
Text: manageStorageButtonDelete,
},
&tg.KeyboardButton{
Text: manageStorageButtonEdit,
},
},
},
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButton{
Text: manageStorageButtonSetDefault,
},
},
},
},
}
)
func supportedMediaFilter(m *tg.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups

View File

@@ -24,6 +24,15 @@ type Alist struct {
config config.AlistConfig
}
var ConfigurableItems = []string{
"url",
"username",
"password",
"base_path",
"token_exp",
"token",
}
func (a *Alist) Init(model types.StorageModel) error {
var alistConfig config.AlistConfig
if err := json.Unmarshal([]byte(model.Config), &alistConfig); err != nil {

View File

@@ -16,6 +16,10 @@ type Local struct {
config config.LocalConfig
}
var ConfigurableItems = []string{
"base_path",
}
func (l *Local) Init(model types.StorageModel) error {
var localConfig config.LocalConfig
if err := json.Unmarshal([]byte(model.Config), &localConfig); err != nil {

View File

@@ -3,6 +3,7 @@ package storage
import (
"context"
"errors"
"fmt"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
@@ -39,26 +40,36 @@ func GetStorageFromModel(model types.StorageModel) (Storage, error) {
return storage, nil
}
func NewStorage(storageModel types.StorageModel) (Storage, error) {
switch storageModel.Type {
case string(types.StorageTypeAlist):
alistStorage := new(alist.Alist)
if err := alistStorage.Init(storageModel); err != nil {
return nil, err
}
return alistStorage, nil
case string(types.StorageTypeLocal):
localStorage := new(local.Local)
if err := localStorage.Init(storageModel); err != nil {
return nil, err
}
return localStorage, nil
case string(types.StorageTypeWebdav):
webdavStorage := new(webdav.Webdav)
if err := webdavStorage.Init(storageModel); err != nil {
return nil, err
}
return webdavStorage, nil
}
return nil, nil
type StorageConstructor func() Storage
var storageConstructors = map[string]StorageConstructor{
string(types.StorageTypeAlist): func() Storage { return new(alist.Alist) },
string(types.StorageTypeLocal): func() Storage { return new(local.Local) },
string(types.StorageTypeWebdav): func() Storage { return new(webdav.Webdav) },
}
func NewStorage(model types.StorageModel) (Storage, error) {
constructor, ok := storageConstructors[model.Type]
if !ok {
return nil, fmt.Errorf("unsupported storage type: %s", model.Type)
}
storage := constructor()
if err := storage.Init(model); err != nil {
return nil, fmt.Errorf("failed to init %s storage: %w", model.Type, err)
}
return storage, nil
}
func GetStorageConfigurableItems(storageType types.StorageType) []string {
switch storageType {
case types.StorageTypeAlist:
return alist.ConfigurableItems
case types.StorageTypeLocal:
return local.ConfigurableItems
case types.StorageTypeWebdav:
return webdav.ConfigurableItems
}
return nil
}

View File

@@ -16,11 +16,10 @@ import (
type Webdav struct {
config config.WebdavConfig
client *gowebdav.Client
}
var (
Client *gowebdav.Client
)
var ConfigurableItems = []string{"url", "username", "password", "base_path"}
func (w *Webdav) Init(model types.StorageModel) error {
var webdavConfig config.WebdavConfig
@@ -28,11 +27,12 @@ func (w *Webdav) Init(model types.StorageModel) error {
return fmt.Errorf("failed to unmarshal webdav config: %w", err)
}
w.config = webdavConfig
Client = gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := Client.Connect(); err != nil {
client := gowebdav.NewClient(webdavConfig.URL, webdavConfig.Username, webdavConfig.Password)
if err := client.Connect(); err != nil {
return fmt.Errorf("failed to connect to webdav server: %w", err)
}
Client.SetTimeout(12 * time.Hour)
client.SetTimeout(12 * time.Hour)
w.client = client
return nil
}
@@ -41,7 +41,7 @@ func (w *Webdav) Type() types.StorageType {
}
func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
if err := Client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
if err := w.client.MkdirAll(path.Dir(storagePath), os.ModePerm); err != nil {
logger.L.Errorf("Failed to create directory %s: %v", path.Dir(storagePath), err)
return ErrFailedToCreateDirectory
}
@@ -52,7 +52,7 @@ func (w *Webdav) Save(ctx context.Context, filePath, storagePath string) error {
}
defer file.Close()
if err := Client.WriteStream(storagePath, file, os.ModePerm); err != nil {
if err := w.client.WriteStream(storagePath, file, os.ModePerm); err != nil {
logger.L.Errorf("Failed to write file %s: %v", storagePath, err)
return ErrFailedToWriteFile
}

View File

@@ -22,13 +22,17 @@ var (
type StorageType string
var (
StorageAll StorageType = "all"
StorageTypeLocal StorageType = "local"
StorageTypeWebdav StorageType = "webdav"
StorageTypeAlist StorageType = "alist"
)
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav, StorageAll}
var StorageTypes = []StorageType{StorageTypeLocal, StorageTypeAlist, StorageTypeWebdav}
var StorageTypeDisplay = map[StorageType]string{
StorageTypeLocal: "本地磁盘",
StorageTypeWebdav: "WebDAV",
StorageTypeAlist: "Alist",
}
type Task struct {
Ctx context.Context