From 19535d0438104166b3cafd35ca32a6db5a02082c Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:17:27 +0800 Subject: [PATCH] feat: parse media group, wip --- bot/handle_link.go | 147 ++++++++++++++++++++++++++++++++------------- bot/utils.go | 70 +++++++++++++++++++++ dao/file.go | 13 ++++ 3 files changed, 187 insertions(+), 43 deletions(-) diff --git a/bot/handle_link.go b/bot/handle_link.go index 536ed9a..a148d96 100644 --- a/bot/handle_link.go +++ b/bot/handle_link.go @@ -1,7 +1,9 @@ package bot import ( + "errors" "fmt" + "net/url" "regexp" "strconv" "strings" @@ -16,39 +18,99 @@ import ( ) var ( - linkRegexString = `t.me/.*/\d+` + linkRegexString = `https?://t\.me/(?:c/\d+|[a-zA-Z0-9_]+)/\d+(?:\?[^\s]*)?` linkRegex = regexp.MustCompile(linkRegexString) ) -func parseLink(ctx *ext.Context, link string) (chatID int64, messageID int, err error) { - strSlice := strings.Split(link, "/") - if len(strSlice) < 3 { - return 0, 0, fmt.Errorf("链接格式错误: %s", link) - } - messageID, err = strconv.Atoi(strSlice[len(strSlice)-1]) +type parseResult struct { + ChatID int64 + MessageID int + Files []*types.File + UserClient bool +} + +func parseLink(ctx *ext.Context, link string) (*parseResult, error) { + u, err := url.Parse(link) if err != nil { - return 0, 0, fmt.Errorf("无法解析消息 ID: %s", err) + return nil, fmt.Errorf("无法解析链接: %s", err) } + strSlice := strings.Split(u.Path, "/") + if len(strSlice) < 3 { + return nil, fmt.Errorf("链接格式错误: %s", link) + } + messageID, err := strconv.Atoi(strSlice[len(strSlice)-1]) + if err != nil { + return nil, fmt.Errorf("无法解析消息 ID: %s", err) + } + var chatID int64 if len(strSlice) == 3 { chatUsername := strSlice[1] - linkChat, err := ctx.ResolveUsername(chatUsername) - if err != nil { - return 0, 0, fmt.Errorf("解析用户名失败: %s", err) + peer := ctx.PeerStorage.GetPeerByUsername(chatUsername) + if peer != nil { + chatID = peer.ID + } else { + linkChat, err := ctx.ResolveUsername(chatUsername) + if err != nil { + return nil, fmt.Errorf("解析用户名失败: %s", err) + } + if linkChat == nil { + return nil, fmt.Errorf("找不到该聊天: %s", chatUsername) + } + chatID = linkChat.GetID() } - if linkChat == nil { - return 0, 0, fmt.Errorf("找不到该聊天: %s", chatUsername) - } - chatID = linkChat.GetID() } else if len(strSlice) == 4 { chatIDInt, err := strconv.Atoi(strSlice[2]) if err != nil { - return 0, 0, fmt.Errorf("无法解析 Chat ID: %s", err) + return nil, fmt.Errorf("无法解析 Chat ID: %s", err) } chatID = int64(chatIDInt) } else { - return 0, 0, fmt.Errorf("无效的链接: %s", link) + return nil, errors.New("链接格式不正确,无法解析 Chat ID") } - return chatID, messageID, nil + if chatID == 0 || messageID == 0 { + return nil, fmt.Errorf("链接中缺少 Chat ID 或 Message ID: %s", link) + } + msg, err := tryFetchMessage(ctx, chatID, messageID) + if err != nil { + return nil, fmt.Errorf("获取消息失败: %s", err) + } + mediaGroup, isGroup := msg.GetGroupedID() + if u.Query().Has("single") || !isGroup || (mediaGroup == 0) || userclient.UC == nil { + file, useUserClient, err := tryFetchFileFromMessage(ctx, chatID, messageID, "") + if err != nil { + return nil, fmt.Errorf("获取文件失败: %s", err) + } + if file.FileName == "" { + file.FileName = GenFileNameFromMessage(*msg, file) + } + return &parseResult{ + ChatID: chatID, + MessageID: messageID, + Files: []*types.File{file}, + UserClient: useUserClient, + }, nil + } + groupMessages, isUserClient, err := tryGetMediaGroup(chatID, messageID, mediaGroup) + if err != nil { + return nil, fmt.Errorf("获取媒体组消息失败: %s", err) + } + var files []*types.File + for _, groupMsg := range groupMessages { + file, err := FileFromMedia(groupMsg.Media, "") + if err != nil { + return nil, fmt.Errorf("获取媒体文件失败: %s", err) + } + if file.FileName == "" { + file.FileName = GenFileNameFromMessage(*groupMsg, file) + } + files = append(files, file) + } + return &parseResult{ + ChatID: chatID, + MessageID: messageID, + Files: files, + UserClient: isUserClient, + }, nil } // use passed ctx client to fetch file from message, @@ -79,6 +141,18 @@ func tryFetchFileFromMessage(ctx *ext.Context, chatID int64, messageID int, file return nil, false, err } +func tryGetMediaGroup(chatID int64, messageID int, mediaGroupID int64) ([]*tg.Message, bool, error) { + if userclient.UC != nil { + uctx := userclient.GetCtx() + messages, err := GetMediaGroup(uctx, chatID, messageID, mediaGroupID) + if err != nil { + return nil, true, fmt.Errorf("failed to get media group from userbot: %w", err) + } + return messages, true, nil + } + return nil, false, errors.New("userclient is not available, cannot fetch media group") +} + func tryFetchMessage(ctx *ext.Context, chatID int64, messageID int) (*tg.Message, error) { return GetTGMessage(ctx, chatID, messageID) } @@ -89,10 +163,10 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { if link == "" { return dispatcher.ContinueGroups } - linkChatID, messageID, err := parseLink(ctx, link) + result, err := parseLink(ctx, link) if err != nil { common.Log.Errorf("解析链接失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("解析链接失败: "+err.Error()), nil) + ctx.Reply(update, ext.ReplyTextString("解析链接失败"), nil) return dispatcher.EndGroups } @@ -109,29 +183,15 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } - file, useUserClient, err := tryFetchFileFromMessage(ctx, linkChatID, messageID, "") - if err != nil { - common.Log.Errorf("获取文件失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil) - return dispatcher.EndGroups - } - if file.FileName == "" { - msg, err := tryFetchMessage(ctx, linkChatID, messageID) - if err != nil { - file.FileName = fmt.Sprintf("%d_%d", linkChatID, messageID) - } else { - file.FileName = GenFileNameFromMessage(*msg, file) - } - } - + // TODO: handle group files receivedFile := &dao.ReceivedFile{ Processing: false, - FileName: file.FileName, - ChatID: linkChatID, - MessageID: messageID, + FileName: result.Files[0].FileName, + ChatID: result.ChatID, + MessageID: result.MessageID, ReplyMessageID: replied.ID, ReplyChatID: update.GetUserChat().GetID(), - UseUserClient: useUserClient, + UseUserClient: result.UserClient, } record, err := dao.SaveReceivedFile(receivedFile) if err != nil { @@ -142,19 +202,20 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { }) return dispatcher.EndGroups } + file := result.Files[0] if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file.FileName, linkChatID, messageID, replied.ID) + return ProvideSelectMessage(ctx, update, file.FileName, result.ChatID, result.MessageID, replied.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, Status: types.Pending, FileDBID: record.ID, - UseUserClient: useUserClient, + UseUserClient: result.UserClient, File: file, StorageName: user.DefaultStorage, UserID: user.ChatID, - FileChatID: linkChatID, - FileMessageID: messageID, + FileChatID: result.ChatID, + FileMessageID: result.MessageID, ReplyMessageID: replied.ID, ReplyChatID: update.GetUserChat().GetID(), }) diff --git a/bot/utils.go b/bot/utils.go index 556e634..f915054 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -237,6 +237,8 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e return tgMessage, nil } +// Userbot only +// // https://github.com/iyear/tdl/blob/fbb396da774ba544e527c3ef41c44921ad74ee98/core/util/tutil/tutil.go#L174 func GetSingleHistoryMessage(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, msg int) (*tg.Message, error) { it := query.Messages(client).GetHistory(peer).OffsetID(msg + 1).BatchSize(1).Iter() @@ -256,6 +258,74 @@ func GetSingleHistoryMessage(ctx context.Context, client *tg.Client, peer tg.Inp return m, nil } +// Userbot only +func GetHistoryMessages(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, startID, limit int) ([]*tg.Message, error) { + endID := startID + limit - 1 + msgs, err := client.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{ + Peer: peer, + OffsetID: startID, + Limit: limit, + AddOffset: startID - endID, + }) + if err != nil { + return nil, fmt.Errorf("failed to get history messages: %w", err) + } + var msgClass []tg.MessageClass + switch msgsv := msgs.(type) { + case *tg.MessagesMessages: + msgClass = msgsv.GetMessages() + case *tg.MessagesMessagesSlice: + msgClass = msgsv.GetMessages() + case *tg.MessagesChannelMessages: + msgClass = msgsv.GetMessages() + default: + return nil, fmt.Errorf("unexpected messages type: %T", msgs) + } + + messageBatch := make([]*tg.Message, 0, 100) + + for _, msg := range msgClass { + msgNotEmpty, ok := msg.AsNotEmpty() + if !ok { + continue + } + switch msgNotEmptyV := msgNotEmpty.(type) { + case *tg.Message: + messageBatch = append(messageBatch, msgNotEmptyV) + default: + common.Log.Warnf("Unexpected message type: %T, skipping", msgNotEmptyV) + continue + } + } + if len(messageBatch) == 0 { + return nil, fmt.Errorf("no messages found for peer %s with startID %d and limit %d", peer, startID, limit) + } + return messageBatch, nil +} + +func GetMediaGroup(ctx *ext.Context, chatID int64, messageID int, groupID int64) ([]*tg.Message, error) { + peer := ctx.PeerStorage.GetInputPeerById(chatID) + if peer == nil { + return nil, fmt.Errorf("无法获取聊天 %d 的输入 Peer", chatID) + } + messages, err := GetHistoryMessages(ctx, ctx.Raw, peer, messageID-9, 20) + if err != nil { + return nil, fmt.Errorf("获取消息失败: %s", err) + } + var groupMessages []*tg.Message + for _, msg := range messages { + common.Log.Debugf("Checking message %d in group %d", msg.ID, groupID) + gID, isGroup := msg.GetGroupedID() + if isGroup && gID == groupID { + groupMessages = append(groupMessages, msg) + } + } + if len(groupMessages) == 0 || (len(groupMessages) == 1 && groupMessages[0].ID == messageID) { + return nil, fmt.Errorf("未找到媒体组 %d 中的消息", groupID) + } + return groupMessages, nil +} + func GetInputPeerID(peer tg.InputPeerClass) int64 { switch p := peer.(type) { case *tg.InputPeerUser: diff --git a/dao/file.go b/dao/file.go index fac81e4..01fad70 100644 --- a/dao/file.go +++ b/dao/file.go @@ -9,6 +9,19 @@ func SaveReceivedFile(receivedFile *ReceivedFile) (*ReceivedFile, error) { return receivedFile, db.Error } +func BatchSaveReceivedFiles(receivedFiles []*ReceivedFile) error { + if len(receivedFiles) == 0 { + return nil + } + for _, file := range receivedFiles { + record, err := GetReceivedFileByChatAndMessageID(file.ChatID, file.MessageID) + if err == nil { + file.ID = record.ID + } + } + return db.Save(receivedFiles).Error +} + func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*ReceivedFile, error) { var receivedFile ReceivedFile err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error