refactor: complete core features

This commit is contained in:
krau
2024-11-09 09:07:00 +08:00
parent fbdfc04ad8
commit 20e06fbf46
14 changed files with 380 additions and 106 deletions

View File

@@ -14,11 +14,9 @@ import (
"github.com/celestix/gotgproto/dispatcher/handlers"
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/model"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
@@ -41,7 +39,8 @@ const noPermissionText string = `
`
func checkPermission(ctx *ext.Context, update *ext.Update) error {
if !slice.Contain(config.Cfg.Telegram.Admins, update.EffectiveUser().ID) {
userID := update.GetUserChat().GetID()
if !slice.Contain(config.Cfg.Telegram.Admins, userID) {
ctx.Reply(update, noPermissionText, nil)
return dispatcher.EndGroups
}
@@ -49,7 +48,7 @@ func checkPermission(ctx *ext.Context, update *ext.Update) error {
}
func start(ctx *ext.Context, update *ext.Update) error {
if err := dao.CreateUser(update.EffectiveUser().ID); err != nil {
if err := dao.CreateUser(update.GetUserChat().GetID()); err != nil {
logger.L.Errorf("Failed to create user: %s", err)
return dispatcher.EndGroups
}
@@ -74,7 +73,7 @@ func help(ctx *ext.Context, update *ext.Update) error {
}
func silent(ctx *ext.Context, update *ext.Update) error {
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
@@ -116,7 +115,7 @@ func setDefaultStorage(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, "存储位置不存在", nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
@@ -145,7 +144,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.EffectiveUser().ID)
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
@@ -157,7 +156,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
media := update.EffectiveMessage.Media
file, err := common.FileFromMedia(media)
file, err := FileFromMedia(media)
if err != nil {
logger.L.Errorf("Failed to get file from media: %s", err)
ctx.Reply(update, "无法获取文件", nil)
@@ -168,7 +167,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if err := dao.AddReceivedFile(&model.ReceivedFile{
if err := dao.AddReceivedFile(&types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: update.EffectiveChat().GetID(),
@@ -210,7 +209,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
queue.AddTask(types.Task{
Ctx: ctx,
Status: types.Pending,
FileName: file.FileName,
File: file,
Storage: types.StorageType(user.DefaultStorage),
ChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
@@ -234,17 +233,29 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "无法添加到队列",
Message: "查询记录失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, Client, record.ChatID, record.MessageID)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: "获取消息文件失败",
CacheTime: 5,
})
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
Ctx: ctx,
Status: types.Pending,
FileName: record.FileName,
File: file,
Storage: types.StorageType(args[2]),
ChatID: update.EffectiveChat().GetID(),
ChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
MessageID: record.MessageID,
})

View File

@@ -1,15 +1,20 @@
package bot
import (
"context"
"fmt"
"github.com/celestix/gotgproto"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/types"
tgTypes "github.com/celestix/gotgproto/types"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func supportedMediaFilter(m *types.Message) (bool, error) {
func supportedMediaFilter(m *tgTypes.Message) (bool, error) {
if not := m.Media == nil; not {
return false, dispatcher.EndGroups
}
@@ -69,3 +74,69 @@ func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
},
}
}
func FileFromMedia(media tg.MessageMediaClass) (*types.File, error) {
switch media := media.(type) {
case *tg.MessageMediaDocument:
document, ok := media.Document.AsNotEmpty()
if !ok {
return nil, fmt.Errorf("unexpected type %T", media)
}
var fileName string
for _, attribute := range document.Attributes {
if name, ok := attribute.(*tg.DocumentAttributeFilename); ok {
fileName = name.FileName
break
}
}
return &types.File{
Location: document.AsInputDocumentFileLocation(),
FileSize: document.Size,
FileName: fileName,
MimeType: document.MimeType,
ID: document.ID,
}, nil
}
return nil, fmt.Errorf("unexpected type %T", media)
}
func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64, messageID int) (*types.File, error) {
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
logger.L.Debugf("Getting file: %s", key)
var cachedFile types.File
err := common.Cache.Get(key, &cachedFile)
if err == nil {
return &cachedFile, nil
}
message, err := GetTGMessage(ctx, client, messageID)
if err != nil {
return nil, err
}
file, err := FileFromMedia(message.Media)
if err != nil {
return nil, err
}
if err := common.Cache.Set(key, file, 3600); err != nil {
logger.L.Errorf("Failed to cache file: %s", err)
}
return file, nil
}
func GetTGMessage(ctx context.Context, client *gotgproto.Client, messageID int) (*tg.Message, error) {
logger.L.Debugf("Fetching message: %d", messageID)
res, err := client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{
&tg.InputMessageID{
ID: messageID,
},
})
if err != nil {
return nil, err
}
messages := res.(*tg.MessagesMessages)
msg := messages.Messages[0]
if _, ok := msg.(*tg.Message); !ok {
return nil, fmt.Errorf("unexpected type %T, this file may be deleted", msg)
}
return msg.(*tg.Message), nil
}