refactor: migrate to gotd (wip)

This commit is contained in:
krau
2024-11-08 23:00:57 +08:00
parent 32bd391129
commit fbdfc04ad8
14 changed files with 465 additions and 387 deletions

View File

@@ -1,59 +1,54 @@
package bot
import (
"context"
"os"
"time"
"github.com/amarnathcjd/gogram/telegram"
"github.com/celestix/gotgproto"
"github.com/celestix/gotgproto/sessionMaker"
"github.com/glebarez/sqlite"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/logger"
)
var (
Client *telegram.Client
)
var Client *gotgproto.Client
func Init() {
logger.L.Debug("Initializing bot...")
var err error
Client, err = telegram.NewClient(telegram.ClientConfig{
AppID: config.Cfg.Telegram.AppID,
AppHash: config.Cfg.Telegram.AppHash,
LogLevel: telegram.LogInfo,
logger.L.Info("Initializing client...")
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resultChan := make(chan struct {
client *gotgproto.Client
err error
})
if err != nil {
logger.L.Fatal("Failed to create telegram client: ", err)
go func() {
client, err := gotgproto.NewClient(int(config.Cfg.Telegram.AppID), config.Cfg.Telegram.AppHash, gotgproto.ClientTypeBot(config.Cfg.Telegram.Token),
&gotgproto.ClientOpts{
Session: sessionMaker.SqlSession(sqlite.Open("data/session.db")),
DisableCopyright: true,
Middlewares: FloodWaitMiddleware(),
},
)
resultChan <- struct {
client *gotgproto.Client
err error
}{client, err}
}()
select {
case <-ctx.Done():
logger.L.Fatal("Failed to initialize client")
os.Exit(1)
case result := <-resultChan:
if result.err != nil {
logger.L.Fatalf("Failed to initialize client: %s", result.err)
os.Exit(1)
}
Client = result.client
RegisterHandlers(Client.Dispatcher)
logger.L.Info("Client initialized")
}
if err := Client.LoginBot(config.Cfg.Telegram.Token); err != nil {
logger.L.Fatal("Failed to login bot: ", err)
os.Exit(1)
}
logger.L.Info("Bot logged in")
_, err = Client.BotsSetBotCommands(&telegram.BotCommandScopeDefault{}, "", []*telegram.BotCommand{
{Command: "start", Description: "开始使用"},
{Command: "help", Description: "显示帮助"},
{Command: "silent", Description: "静默模式"},
{Command: "storage", Description: "设置默认存储位置"},
{Command: "save", Description: "保存所回复文件"},
})
if err != nil {
logger.L.Errorf("Failed to set bot commands: ", err)
}
logger.L.Info("Bot initialized")
}
func Run() {
if Client == nil {
Init()
}
Client.On("command:start", Start, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...))
Client.On("command:help", Help, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...))
Client.On("command:silent", ChangeSilentMode, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...))
Client.On("command:storage", SetDefaultStorage, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...))
Client.On("command:save", SaveCmd, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...))
Client.On(telegram.OnMessage, HandleFileMessage, telegram.FilterPrivate, telegram.FilterChats(config.Cfg.Telegram.Admins...), telegram.FilterMedia)
Client.On("callback:add", AddToQueue)
Client.Idle()
}

View File

@@ -1,33 +1,62 @@
package bot
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/amarnathcjd/gogram/telegram"
"github.com/duke-git/lancet/v2/slice"
"github.com/gookit/goutil/maputil"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/dispatcher/handlers"
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/model"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
"github.com/mymmrac/telego/telegoutil"
)
func Start(message *telegram.NewMessage) error {
if err := dao.CreateUser(message.ChatID()); err != nil {
logger.L.Errorf("Failed to create user: %s", err)
return err
}
return Help(message)
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewAnyUpdate(checkPermission))
dispatcher.AddHandler(handlers.NewCommand("start", start))
dispatcher.AddHandler(handlers.NewCommand("help", help))
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}
func Help(message *telegram.NewMessage) error {
helpText := `
const noPermissionText string = `
本 Bot 仅限个人使用.
您可以部署自己的实例: https://github.com/krau/SaveAny-Bot
`
func checkPermission(ctx *ext.Context, update *ext.Update) error {
if !slice.Contain(config.Cfg.Telegram.Admins, update.EffectiveUser().ID) {
ctx.Reply(update, noPermissionText, nil)
return dispatcher.EndGroups
}
return dispatcher.ContinueGroups
}
func start(ctx *ext.Context, update *ext.Update) error {
if err := dao.CreateUser(update.EffectiveUser().ID); err != nil {
logger.L.Errorf("Failed to create user: %s", err)
return dispatcher.EndGroups
}
return help(ctx, update)
}
const helpText string = `
SaveAny Bot - 转存你的 Telegram 文件
命令:
/start - 开始使用
@@ -37,235 +66,191 @@ SaveAny Bot - 转存你的 Telegram 文件
/save - 保存文件
静默模式: 开启后 Bot 直接保存到收到的文件到默认位置, 不再询问
`
if _, err := message.Reply(helpText); err != nil {
logger.L.Errorf("Failed to send help message: %s", err)
return err
}
return nil
`
func help(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, helpText, nil)
return dispatcher.EndGroups
}
func ChangeSilentMode(message *telegram.NewMessage) error {
user, err := dao.GetUserByUserID(message.ChatID())
func silent(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
if err != nil {
logger.L.Error(err)
return err
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
user.Silent = !user.Silent
err = dao.UpdateUser(user)
if err != nil {
logger.L.Error(err)
return err
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
return dispatcher.EndGroups
}
if _, err := message.Reply(fmt.Sprintf("已%s静默模式", map[bool]string{true: "开启", false: "关闭"}[user.Silent])); err != nil {
return err
}
return nil
ctx.Reply(update, fmt.Sprintf("已%s静默模式", func() string {
if user.Silent {
return "开启"
}
return "关闭"
}()), nil)
return dispatcher.EndGroups
}
func SetDefaultStorage(message *telegram.NewMessage) error {
func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
if len(storage.Storages) == 0 {
message.Reply("当前无可用存储端, 请检查配置.")
return nil
ctx.Reply(update, "未配置存储", nil)
return dispatcher.EndGroups
}
_, _, args := telegoutil.ParseCommand(message.Text())
availableStorages := maputil.Keys(storage.Storages)
if len(args) == 0 {
text := "请提供存储位置名称, 可用项:"
for _, name := range availableStorages {
text += fmt.Sprintf("\n`%s`", name)
args := strings.Split(update.EffectiveMessage.Text, " ")
avaliableStorages := maputil.Keys(storage.Storages)
if len(args) < 2 {
text := []styling.StyledTextOption{
styling.Plain("请提供存储位置名称, 可用项:"),
}
text += fmt.Sprintf("\n`all`")
message.Reply(text, telegram.SendOptions{ParseMode: telegram.MarkDown})
return nil
for _, name := range avaliableStorages {
text = append(text, styling.Plain("\n"))
text = append(text, styling.Code(name))
}
ctx.Reply(update, text, nil)
return dispatcher.EndGroups
}
storageName := args[0]
if !slice.Contain(availableStorages, storageName) {
message.Reply("参数错误")
return nil
storageName := args[1]
if !slice.Contain(avaliableStorages, storageName) {
ctx.Reply(update, "存储位置不存在", nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(message.ChatID())
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
if err != nil {
logger.L.Error(err)
return err
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
user.DefaultStorage = storageName
err = dao.UpdateUser(user)
if err != nil {
logger.L.Error(err)
return err
if err := dao.UpdateUser(user); err != nil {
logger.L.Errorf("Failed to update user: %s", err)
return dispatcher.EndGroups
}
if _, err := message.Reply(fmt.Sprintf("已设置默认存储位置为: %s", storageName)); err != nil {
return err
}
return nil
ctx.Reply(update, fmt.Sprintf("已设置默认存储位置为 %s", storageName), nil)
return dispatcher.EndGroups
}
func SaveCmd(message *telegram.NewMessage) error {
targetMessage, err := message.GetReplyMessage()
if err != nil {
message.Reply("请回复要保存的文件")
return nil
}
if !targetMessage.IsMedia() {
message.Reply("回复的消息不包含文件")
return nil
}
func saveCmd(ctx *ext.Context, update *ext.Update) error {
// TODO: Implement save command
return dispatcher.EndGroups
}
msg, err := targetMessage.Reply("正在获取文件信息...")
func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Debug("Got media: ", update.EffectiveMessage.Media.TypeName())
supported, err := supportedMediaFilter(update.EffectiveMessage)
if err != nil {
logger.L.Error(err)
message.Reply("获取文件信息失败")
return err
}
_, _, _, fileName, err := telegram.GetFileLocation(targetMessage.Media())
if err != nil {
logger.L.Error(err)
targetMessage.Reply("获取文件信息失败")
return err
if !supported {
return dispatcher.EndGroups
}
if fileName == "" {
logger.L.Error("Empty file name")
targetMessage.Reply("文件名为空")
return nil
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
msg, err := ctx.Reply(update, "正在获取文件信息...", nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
media := update.EffectiveMessage.Media
file, err := common.FileFromMedia(media)
if err != nil {
logger.L.Errorf("Failed to get file from media: %s", err)
ctx.Reply(update, "无法获取文件", nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
ctx.Reply(update, "无法获取文件名", nil)
return dispatcher.EndGroups
}
if err := dao.AddReceivedFile(&model.ReceivedFile{
Processing: false,
FileName: fileName,
ChatID: targetMessage.ChatID(),
MessageID: targetMessage.Message.ID,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.ID,
ReplyMessageID: msg.ID,
}); err != nil {
logger.L.Error(err)
msg.Edit("保存文件信息失败")
return err
}
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件",
ID: msg.ID,
}); err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
user, err := dao.GetUserByUserID(message.ChatID())
if err != nil {
logger.L.Error(err)
msg.Edit("获取用户信息失败")
return err
return dispatcher.EndGroups
}
if !user.Silent {
msg.Edit("请选择要保存的位置:", telegram.SendOptions{
ReplyMarkup: AddTaskReplyMarkup(targetMessage.Message.ID),
text := "请选择存储位置"
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
ReplyMarkup: getAddTaskMarkup(update.EffectiveMessage.ID),
ID: msg.ID,
})
return nil
if err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
if user.DefaultStorage == "" {
msg.Edit("请先使用 /storage 命令设置默认存储位置, 或者关闭静默模式")
return nil
}
queue.AddTask(types.Task{
Ctx: context.TODO(),
Status: types.Pending,
FileName: fileName,
Storage: types.StorageType(user.DefaultStorage),
ChatID: targetMessage.ChatID(),
MessageID: targetMessage.Message.ID,
ReplyMessageID: msg.ID,
})
msg.Edit(fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", fileName, queue.Len()))
return nil
}
func HandleFileMessage(message *telegram.NewMessage) error {
if !message.IsMedia() {
return nil
}
user, err := dao.GetUserByUserID(message.ChatID())
if err != nil {
logger.L.Error(err)
return nil
}
msg, err := message.Reply("正在获取文件信息...")
if err != nil {
logger.L.Error(err)
return err
}
_, _, _, fileName, err := telegram.GetFileLocation(message.Media())
if err != nil {
logger.L.Error(err)
message.Reply("获取文件信息失败")
return err
}
if fileName == "" {
logger.L.Error("Empty file name")
message.Reply("文件名为空")
return nil
}
if err := dao.AddReceivedFile(&model.ReceivedFile{
Processing: false,
FileName: fileName,
ChatID: message.ChatID(),
MessageID: message.Message.ID,
ReplyMessageID: msg.ID,
}); err != nil {
logger.L.Error(err)
msg.Edit("保存文件信息失败")
return err
}
if !user.Silent {
msg.Edit("请选择要保存的位置:", telegram.SendOptions{
ReplyMarkup: AddTaskReplyMarkup(message.Message.ID),
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
ID: msg.ID,
})
return nil
}
if user.DefaultStorage == "" {
msg.Edit("请先使用 /storage 命令设置默认存储位置, 或者关闭静默模式")
return nil
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: context.TODO(),
Ctx: ctx,
Status: types.Pending,
FileName: fileName,
FileName: file.FileName,
Storage: types.StorageType(user.DefaultStorage),
ChatID: message.ChatID(),
MessageID: message.Message.ID,
ChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
MessageID: update.EffectiveMessage.ID,
})
msg.Edit(fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", fileName, queue.Len()))
return nil
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()),
ID: msg.ID,
})
return dispatcher.EndGroups
}
func AddToQueue(query *telegram.CallbackQuery) error {
args := strings.Split(query.DataString(), " ")
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
args := strings.Split(string(update.CallbackQuery.Data), " ")
messageID, _ := strconv.Atoi(args[1])
logger.L.Debug(query.ChatID, messageID)
receivedFile, err := dao.GetReceivedFileByChatAndMessageID(query.ChatID, int32(messageID))
logger.L.Debugf("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
record, err := dao.GetReceivedFileByChatAndMessageID(update.EffectiveChat().GetID(), messageID)
if err != nil {
logger.L.Error(err)
query.Answer("获取文件信息失败", &telegram.CallbackOptions{
logger.L.Errorf("Failed to get received file: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "无法添加到队列",
CacheTime: 5,
})
return err
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: context.TODO(),
Ctx: ctx,
Status: types.Pending,
FileName: receivedFile.FileName,
FileName: record.FileName,
Storage: types.StorageType(args[2]),
ChatID: receivedFile.ChatID,
MessageID: receivedFile.MessageID,
ReplyMessageID: receivedFile.ReplyMessageID,
ChatID: update.EffectiveChat().GetID(),
ReplyMessageID: record.ReplyMessageID,
MessageID: record.MessageID,
})
query.Edit(fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", receivedFile.FileName, queue.Len()))
return nil
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", record.FileName, queue.Len()),
ID: record.ReplyMessageID,
})
return dispatcher.EndGroups
}

19
bot/middlewares.go Normal file
View File

@@ -0,0 +1,19 @@
package bot
import (
"time"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
"github.com/gotd/td/telegram"
"golang.org/x/time/rate"
)
func FloodWaitMiddleware() []telegram.Middleware {
waiter := floodwait.NewSimpleWaiter().WithMaxRetries(5)
ratelimiter := ratelimit.New(rate.Every(time.Millisecond*100), 5)
return []telegram.Middleware{
waiter,
ratelimiter,
}
}

View File

@@ -2,12 +2,29 @@ package bot
import (
"fmt"
"regexp"
"github.com/amarnathcjd/gogram/telegram"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/types"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/storage"
)
func supportedMediaFilter(m *types.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups
}
switch m.Media.(type) {
case *tg.MessageMediaDocument:
return true, nil
case *tg.MessageMediaWebPage:
return false, dispatcher.EndGroups
case tg.MessageMediaClass:
return false, dispatcher.EndGroups
default:
return false, nil
}
}
var StorageDisplayNames = map[string]string{
"all": "全部",
"local": "服务器磁盘",
@@ -15,48 +32,40 @@ var StorageDisplayNames = map[string]string{
"webdav": "WebDAV",
}
func AddTaskReplyMarkup(messageID int32) telegram.ReplyMarkup {
// TODO: sort storage buttons
storageButtons := make([]telegram.KeyboardButton, 0)
for name := range storage.Storages {
storageButtons = append(storageButtons, &telegram.KeyboardButtonCallback{
func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
storageButtons := make([]tg.KeyboardButtonClass, 0)
for _, name := range storage.StorageKeys {
storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{
Text: StorageDisplayNames[string(name)],
Data: []byte(fmt.Sprintf("add %d %s", messageID, name)),
})
}
if len(storageButtons) > 1 {
return &telegram.ReplyInlineMarkup{
Rows: []*telegram.KeyboardButtonRow{
if len(storageButtons) < 1 {
return nil
}
if len(storageButtons) == 1 {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
{
Buttons: []telegram.KeyboardButton{
&telegram.KeyboardButtonCallback{
Text: "全部",
Data: []byte(fmt.Sprintf("add %d all", messageID)),
},
},
}
}
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: storageButtons,
},
{
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButtonCallback{
Text: "全部",
Data: []byte(fmt.Sprintf("add %d all", messageID)),
},
},
},
}
},
}
if len(storageButtons) == 1 {
return &telegram.ReplyInlineMarkup{
Rows: []*telegram.KeyboardButtonRow{
{
Buttons: storageButtons,
},
},
}
}
return nil
}
var markdownRe = regexp.MustCompile("([" + regexp.QuoteMeta(`\_*[]()~`+"`"+`>#+-=|{}.!`) + "])")
func EscapeMarkdown(text string) string {
return markdownRe.ReplaceAllString(text, "\\$1")
}