feat: support save content protect channel message by handle link

This commit is contained in:
krau
2025-02-16 11:38:26 +08:00
parent ec09289d5f
commit db69688722
6 changed files with 198 additions and 117 deletions

97
bot/handle_link.go Normal file
View File

@@ -0,0 +1,97 @@
package bot
import (
"regexp"
"strconv"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/types"
)
var (
linkRegexString = `t.me/.*/\d+`
linkRegex = regexp.MustCompile(linkRegexString)
)
func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
logger.L.Trace("Got link message")
link := linkRegex.FindString(update.EffectiveMessage.Text)
if link == "" {
return dispatcher.ContinueGroups
}
strSlice := strings.Split(link, "/")
if len(strSlice) < 3 {
return dispatcher.ContinueGroups
}
messageID, err := strconv.Atoi(strSlice[2])
if err != nil {
logger.L.Errorf("Failed to parse message ID: %s", err)
ctx.Reply(update, ext.ReplyTextString("Failed to parse message ID"), nil)
return dispatcher.EndGroups
}
chatUsername := strSlice[1]
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
logger.L.Errorf("Failed to resolve chat ID: %s", err)
ctx.Reply(update, ext.ReplyTextString("Failed to resolve chat ID"), nil)
return dispatcher.EndGroups
}
user, err := dao.GetUserByUserID(update.GetUserChat().GetID())
if err != nil {
logger.L.Errorf("Failed to get user: %s", err)
return dispatcher.EndGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, linkChat.GetID(), messageID, "")
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件名",
ID: replied.ID,
})
return dispatcher.EndGroups
}
receivedFile := &types.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: linkChat.GetID(),
MessageID: messageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
logger.L.Errorf("Failed to save received file: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件: " + err.Error(),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if !user.Silent {
return ProvideSelectMessage(ctx, update, file, int(linkChat.GetID()), messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
FileChatID: linkChat.GetID(),
FileMessageID: messageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
})
}

View File

@@ -32,6 +32,11 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewCommand("storage", setDefaultStorage))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCommand("path", setPath))
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
if err != nil {
logger.L.Panicf("Failed to create regex filter: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}
@@ -146,7 +151,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
msg, err := GetTGMessage(ctx, Client, replyToMsgID)
msg, err := GetTGMessage(ctx, update.EffectiveChat().GetID(), replyToMsgID)
if err != nil {
logger.L.Errorf("Failed to get message: %s", err)
ctx.Reply(update, ext.ReplyTextString("无法获取消息"), nil)
@@ -174,11 +179,11 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
cmdText := update.EffectiveMessage.Text
customFileName := strings.TrimSpace(strings.TrimPrefix(cmdText, "/save"))
file, err := FileFromMessage(ctx, Client, update.EffectiveChat().GetID(), msg.ID, customFileName)
file, err := FileFromMessage(ctx, update.EffectiveChat().GetID(), msg.ID, customFileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件",
Message: "获取文件失败: " + err.Error(),
ID: replied.ID,
})
return dispatcher.EndGroups
@@ -186,7 +191,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
if file.FileName == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法获取文件名",
Message: "无法获取文件名, 请使用 /save <自定义文件名> 回复此文件",
ID: replied.ID,
})
return dispatcher.EndGroups
@@ -198,6 +203,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
ChatID: update.EffectiveChat().GetID(),
MessageID: replyToMsgID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
}
if err := dao.SaveReceivedFile(receivedFile); err != nil {
@@ -210,53 +216,19 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
}
return dispatcher.EndGroups
}
if !user.Silent {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(msg.ID),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
}
return dispatcher.EndGroups
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), msg.ID, replied.ID)
}
if user.DefaultStorage == "" {
ctx.Reply(update, ext.ReplyTextString("请先使用 /storage 设置默认存储位置"), nil)
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
ChatID: update.EffectiveChat().GetID(),
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: replied.ID,
MessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: msg.ID,
})
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()),
ID: replied.ID,
})
if err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
}
func setPath(ctx *ext.Context, update *ext.Update) error {
@@ -347,6 +319,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.ID,
ReplyMessageID: msg.ID,
ReplyChatID: update.GetUserChat().GetID(),
}); err != nil {
logger.L.Errorf("Failed to add received file: %s", err)
if _, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
@@ -359,53 +332,18 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
}
if !user.Silent {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err = ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(update.EffectiveMessage.ID),
ID: msg.ID,
})
if err != nil {
logger.L.Errorf("Failed to edit message: %s", err)
}
return dispatcher.EndGroups
return ProvideSelectMessage(ctx, update, file, int(update.EffectiveChat().GetID()), update.EffectiveMessage.ID, msg.ID)
}
if user.DefaultStorage == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
ID: msg.ID,
})
return dispatcher.EndGroups
}
queue.AddTask(types.Task{
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(user.DefaultStorage),
ChatID: update.EffectiveChat().GetID(),
FileChatID: update.EffectiveChat().GetID(),
ReplyMessageID: msg.ID,
MessageID: update.EffectiveMessage.ID,
ReplyChatID: update.GetUserChat().GetID(),
FileMessageID: update.EffectiveMessage.ID,
})
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", file.FileName, queue.Len()),
ID: msg.ID,
})
return dispatcher.EndGroups
}
func AddToQueue(ctx *ext.Context, update *ext.Update) error {
@@ -419,9 +357,11 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
args := strings.Split(string(update.CallbackQuery.Data), " ")
messageID, _ := strconv.Atoi(args[1])
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", update.EffectiveChat().GetID(), messageID, args[2])
record, err := dao.GetReceivedFileByChatAndMessageID(update.EffectiveChat().GetID(), messageID)
chatID, _ := strconv.Atoi(args[1])
messageID, _ := strconv.Atoi(args[2])
storageName := args[3]
logger.L.Tracef("Got add to queue: chatID: %d, messageID: %d, storage: %s", chatID, messageID, storageName)
record, err := dao.GetReceivedFileByChatAndMessageID(int64(chatID), messageID)
if err != nil {
logger.L.Errorf("Failed to get received file: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
@@ -439,7 +379,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
}
}
file, err := FileFromMessage(ctx, Client, record.ChatID, record.MessageID, record.FileName)
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
logger.L.Errorf("Failed to get file from message: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
@@ -455,10 +395,11 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
Ctx: ctx,
Status: types.Pending,
File: file,
Storage: types.StorageType(args[2]),
ChatID: record.ChatID,
Storage: types.StorageType(storageName),
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
MessageID: record.MessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
})
entityBuilder := entity.Builder{}

View File

@@ -1,16 +1,18 @@
package bot
import (
"context"
"errors"
"fmt"
"time"
"github.com/celestix/gotgproto"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/logger"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
@@ -44,12 +46,12 @@ var StorageDisplayNames = map[string]string{
"webdav": "WebDAV",
}
func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
func getAddTaskMarkup(chatID, messageID int) *tg.ReplyInlineMarkup {
storageButtons := make([]tg.KeyboardButtonClass, 0)
for _, name := range storage.StorageKeys {
storageButtons = append(storageButtons, &tg.KeyboardButtonCallback{
Text: StorageDisplayNames[string(name)],
Data: []byte(fmt.Sprintf("add %d %s", messageID, name)),
Data: []byte(fmt.Sprintf("add %d %d %s", chatID, messageID, name)),
})
}
@@ -74,7 +76,7 @@ func getAddTaskMarkup(messageID int) *tg.ReplyInlineMarkup {
Buttons: []tg.KeyboardButtonClass{
&tg.KeyboardButtonCallback{
Text: "全部",
Data: []byte(fmt.Sprintf("add %d all", messageID)),
Data: []byte(fmt.Sprintf("add %d %d all", chatID, messageID)),
},
},
},
@@ -144,7 +146,7 @@ func FileFromMedia(media tg.MessageMediaClass, customFileName string) (*types.Fi
return nil, fmt.Errorf("unexpected type %T", media)
}
func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64, messageID int, customFileName string) (*types.File, error) {
func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileName string) (*types.File, error) {
key := fmt.Sprintf("file:%d:%d", chatID, messageID)
logger.L.Debugf("Getting file: %s", key)
var cachedFile types.File
@@ -152,8 +154,7 @@ func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64
if err == nil {
return &cachedFile, nil
}
message, err := GetTGMessage(ctx, client, messageID)
message, err := GetTGMessage(ctx, chatID, messageID)
if err != nil {
return nil, err
}
@@ -167,20 +168,60 @@ func FileFromMessage(ctx context.Context, client *gotgproto.Client, chatID int64
return file, nil
}
func GetTGMessage(ctx context.Context, client *gotgproto.Client, messageID int) (*tg.Message, error) {
func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) {
logger.L.Debugf("Fetching message: %d", messageID)
res, err := client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{
&tg.InputMessageID{
ID: messageID,
},
})
messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}})
if err != nil {
return nil, err
}
messages := res.(*tg.MessagesMessages)
msg := messages.Messages[0]
if _, ok := msg.(*tg.Message); !ok {
return nil, fmt.Errorf("unexpected type %T, this file may be deleted", msg)
if len(messages) == 0 {
return nil, errors.New("no messages found")
}
return msg.(*tg.Message), nil
msg := messages[0]
tgMessage, ok := msg.(*tg.Message)
if !ok {
return nil, fmt.Errorf("unexpected message type: %T", msg)
}
return tgMessage, nil
}
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int, fileMsgID, toEditMsgID int) error {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
logger.L.Errorf("Failed to build entity: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
_, err := ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ReplyMarkup: getAddTaskMarkup(chatID, fileMsgID),
ID: toEditMsgID,
})
if err != nil {
logger.L.Errorf("Failed to reply: %s", err)
}
return dispatcher.EndGroups
}
func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *types.User, task *types.Task) error {
if user.DefaultStorage == "" {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "请先使用 /storage 设置默认存储位置",
ID: task.ReplyMessageID,
})
return dispatcher.EndGroups
}
queue.AddTask(*task)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
ID: task.ReplyMessageID,
})
return dispatcher.EndGroups
}

View File

@@ -58,7 +58,7 @@ func processPendingTask(task *types.Task) error {
return
}
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
@@ -66,7 +66,7 @@ func processPendingTask(task *types.Task) error {
}
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
@@ -92,7 +92,7 @@ func processPendingTask(task *types.Task) error {
defer cleanCacheFile(cacheDestPath)
logger.L.Infof("Downloaded file: %s", cacheDestPath)
ctx.EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("下载完成: %s\n正在转存文件...", task.FileName()),
ID: task.ReplyMessageID,
})
@@ -124,13 +124,13 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
queue.AddTask(task)
case types.Succeeded:
logger.L.Infof("Task succeeded: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("文件保存成功\n [%s]: %s", task.Storage, task.StoragePath),
ID: task.ReplyMessageID,
})
case types.Failed:
logger.L.Errorf("Task failed: %s", task.String())
task.Ctx.(*ext.Context).EditMessage(task.ChatID, &tg.MessagesEditMessageRequest{
task.Ctx.(*ext.Context).EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: "文件保存失败\n" + task.Error.Error(),
ID: task.ReplyMessageID,
})

View File

@@ -10,6 +10,7 @@ type ReceivedFile struct {
ChatID int64 `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
MessageID int `gorm:"uniqueIndex:idx_chat_id_message_id;not null"`
ReplyMessageID int
ReplyChatID int64
FileName string
}

View File

@@ -37,13 +37,14 @@ type Task struct {
StoragePath string
StartTime time.Time
MessageID int
ChatID int64
FileMessageID int
FileChatID int64
ReplyMessageID int
ReplyChatID int64
}
func (t Task) String() string {
return fmt.Sprintf("[%d:%d]:%s", t.ChatID, t.MessageID, t.File.FileName)
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {