feat: parse media group, wip

This commit is contained in:
krau
2025-06-09 16:17:27 +08:00
parent 693e20b066
commit 19535d0438
3 changed files with 187 additions and 43 deletions

View File

@@ -1,7 +1,9 @@
package bot
import (
"errors"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
@@ -16,39 +18,99 @@ import (
)
var (
linkRegexString = `t.me/.*/\d+`
linkRegexString = `https?://t\.me/(?:c/\d+|[a-zA-Z0-9_]+)/\d+(?:\?[^\s]*)?`
linkRegex = regexp.MustCompile(linkRegexString)
)
func parseLink(ctx *ext.Context, link string) (chatID int64, messageID int, err error) {
strSlice := strings.Split(link, "/")
if len(strSlice) < 3 {
return 0, 0, fmt.Errorf("链接格式错误: %s", link)
}
messageID, err = strconv.Atoi(strSlice[len(strSlice)-1])
type parseResult struct {
ChatID int64
MessageID int
Files []*types.File
UserClient bool
}
func parseLink(ctx *ext.Context, link string) (*parseResult, error) {
u, err := url.Parse(link)
if err != nil {
return 0, 0, fmt.Errorf("无法解析消息 ID: %s", err)
return nil, fmt.Errorf("无法解析链接: %s", err)
}
strSlice := strings.Split(u.Path, "/")
if len(strSlice) < 3 {
return nil, fmt.Errorf("链接格式错误: %s", link)
}
messageID, err := strconv.Atoi(strSlice[len(strSlice)-1])
if err != nil {
return nil, fmt.Errorf("无法解析消息 ID: %s", err)
}
var chatID int64
if len(strSlice) == 3 {
chatUsername := strSlice[1]
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
return 0, 0, fmt.Errorf("解析用户名失败: %s", err)
peer := ctx.PeerStorage.GetPeerByUsername(chatUsername)
if peer != nil {
chatID = peer.ID
} else {
linkChat, err := ctx.ResolveUsername(chatUsername)
if err != nil {
return nil, fmt.Errorf("解析用户名失败: %s", err)
}
if linkChat == nil {
return nil, fmt.Errorf("找不到该聊天: %s", chatUsername)
}
chatID = linkChat.GetID()
}
if linkChat == nil {
return 0, 0, fmt.Errorf("找不到该聊天: %s", chatUsername)
}
chatID = linkChat.GetID()
} else if len(strSlice) == 4 {
chatIDInt, err := strconv.Atoi(strSlice[2])
if err != nil {
return 0, 0, fmt.Errorf("无法解析 Chat ID: %s", err)
return nil, fmt.Errorf("无法解析 Chat ID: %s", err)
}
chatID = int64(chatIDInt)
} else {
return 0, 0, fmt.Errorf("无效的链接: %s", link)
return nil, errors.New("链接格式不正确,无法解析 Chat ID")
}
return chatID, messageID, nil
if chatID == 0 || messageID == 0 {
return nil, fmt.Errorf("链接中缺少 Chat ID 或 Message ID: %s", link)
}
msg, err := tryFetchMessage(ctx, chatID, messageID)
if err != nil {
return nil, fmt.Errorf("获取消息失败: %s", err)
}
mediaGroup, isGroup := msg.GetGroupedID()
if u.Query().Has("single") || !isGroup || (mediaGroup == 0) || userclient.UC == nil {
file, useUserClient, err := tryFetchFileFromMessage(ctx, chatID, messageID, "")
if err != nil {
return nil, fmt.Errorf("获取文件失败: %s", err)
}
if file.FileName == "" {
file.FileName = GenFileNameFromMessage(*msg, file)
}
return &parseResult{
ChatID: chatID,
MessageID: messageID,
Files: []*types.File{file},
UserClient: useUserClient,
}, nil
}
groupMessages, isUserClient, err := tryGetMediaGroup(chatID, messageID, mediaGroup)
if err != nil {
return nil, fmt.Errorf("获取媒体组消息失败: %s", err)
}
var files []*types.File
for _, groupMsg := range groupMessages {
file, err := FileFromMedia(groupMsg.Media, "")
if err != nil {
return nil, fmt.Errorf("获取媒体文件失败: %s", err)
}
if file.FileName == "" {
file.FileName = GenFileNameFromMessage(*groupMsg, file)
}
files = append(files, file)
}
return &parseResult{
ChatID: chatID,
MessageID: messageID,
Files: files,
UserClient: isUserClient,
}, nil
}
// use passed ctx client to fetch file from message,
@@ -79,6 +141,18 @@ func tryFetchFileFromMessage(ctx *ext.Context, chatID int64, messageID int, file
return nil, false, err
}
func tryGetMediaGroup(chatID int64, messageID int, mediaGroupID int64) ([]*tg.Message, bool, error) {
if userclient.UC != nil {
uctx := userclient.GetCtx()
messages, err := GetMediaGroup(uctx, chatID, messageID, mediaGroupID)
if err != nil {
return nil, true, fmt.Errorf("failed to get media group from userbot: %w", err)
}
return messages, true, nil
}
return nil, false, errors.New("userclient is not available, cannot fetch media group")
}
func tryFetchMessage(ctx *ext.Context, chatID int64, messageID int) (*tg.Message, error) {
return GetTGMessage(ctx, chatID, messageID)
}
@@ -89,10 +163,10 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
if link == "" {
return dispatcher.ContinueGroups
}
linkChatID, messageID, err := parseLink(ctx, link)
result, err := parseLink(ctx, link)
if err != nil {
common.Log.Errorf("解析链接失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("解析链接失败: "+err.Error()), nil)
ctx.Reply(update, ext.ReplyTextString("解析链接失败"), nil)
return dispatcher.EndGroups
}
@@ -109,29 +183,15 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
file, useUserClient, err := tryFetchFileFromMessage(ctx, linkChatID, messageID, "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
msg, err := tryFetchMessage(ctx, linkChatID, messageID)
if err != nil {
file.FileName = fmt.Sprintf("%d_%d", linkChatID, messageID)
} else {
file.FileName = GenFileNameFromMessage(*msg, file)
}
}
// TODO: handle group files
receivedFile := &dao.ReceivedFile{
Processing: false,
FileName: file.FileName,
ChatID: linkChatID,
MessageID: messageID,
FileName: result.Files[0].FileName,
ChatID: result.ChatID,
MessageID: result.MessageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
UseUserClient: useUserClient,
UseUserClient: result.UserClient,
}
record, err := dao.SaveReceivedFile(receivedFile)
if err != nil {
@@ -142,19 +202,20 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
})
return dispatcher.EndGroups
}
file := result.Files[0]
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file.FileName, linkChatID, messageID, replied.ID)
return ProvideSelectMessage(ctx, update, file.FileName, result.ChatID, result.MessageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
FileDBID: record.ID,
UseUserClient: useUserClient,
UseUserClient: result.UserClient,
File: file,
StorageName: user.DefaultStorage,
UserID: user.ChatID,
FileChatID: linkChatID,
FileMessageID: messageID,
FileChatID: result.ChatID,
FileMessageID: result.MessageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
})

View File

@@ -237,6 +237,8 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
return tgMessage, nil
}
// Userbot only
//
// https://github.com/iyear/tdl/blob/fbb396da774ba544e527c3ef41c44921ad74ee98/core/util/tutil/tutil.go#L174
func GetSingleHistoryMessage(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, msg int) (*tg.Message, error) {
it := query.Messages(client).GetHistory(peer).OffsetID(msg + 1).BatchSize(1).Iter()
@@ -256,6 +258,74 @@ func GetSingleHistoryMessage(ctx context.Context, client *tg.Client, peer tg.Inp
return m, nil
}
// Userbot only
func GetHistoryMessages(ctx context.Context, client *tg.Client, peer tg.InputPeerClass, startID, limit int) ([]*tg.Message, error) {
endID := startID + limit - 1
msgs, err := client.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{
Peer: peer,
OffsetID: startID,
Limit: limit,
AddOffset: startID - endID,
})
if err != nil {
return nil, fmt.Errorf("failed to get history messages: %w", err)
}
var msgClass []tg.MessageClass
switch msgsv := msgs.(type) {
case *tg.MessagesMessages:
msgClass = msgsv.GetMessages()
case *tg.MessagesMessagesSlice:
msgClass = msgsv.GetMessages()
case *tg.MessagesChannelMessages:
msgClass = msgsv.GetMessages()
default:
return nil, fmt.Errorf("unexpected messages type: %T", msgs)
}
messageBatch := make([]*tg.Message, 0, 100)
for _, msg := range msgClass {
msgNotEmpty, ok := msg.AsNotEmpty()
if !ok {
continue
}
switch msgNotEmptyV := msgNotEmpty.(type) {
case *tg.Message:
messageBatch = append(messageBatch, msgNotEmptyV)
default:
common.Log.Warnf("Unexpected message type: %T, skipping", msgNotEmptyV)
continue
}
}
if len(messageBatch) == 0 {
return nil, fmt.Errorf("no messages found for peer %s with startID %d and limit %d", peer, startID, limit)
}
return messageBatch, nil
}
func GetMediaGroup(ctx *ext.Context, chatID int64, messageID int, groupID int64) ([]*tg.Message, error) {
peer := ctx.PeerStorage.GetInputPeerById(chatID)
if peer == nil {
return nil, fmt.Errorf("无法获取聊天 %d 的输入 Peer", chatID)
}
messages, err := GetHistoryMessages(ctx, ctx.Raw, peer, messageID-9, 20)
if err != nil {
return nil, fmt.Errorf("获取消息失败: %s", err)
}
var groupMessages []*tg.Message
for _, msg := range messages {
common.Log.Debugf("Checking message %d in group %d", msg.ID, groupID)
gID, isGroup := msg.GetGroupedID()
if isGroup && gID == groupID {
groupMessages = append(groupMessages, msg)
}
}
if len(groupMessages) == 0 || (len(groupMessages) == 1 && groupMessages[0].ID == messageID) {
return nil, fmt.Errorf("未找到媒体组 %d 中的消息", groupID)
}
return groupMessages, nil
}
func GetInputPeerID(peer tg.InputPeerClass) int64 {
switch p := peer.(type) {
case *tg.InputPeerUser:

View File

@@ -9,6 +9,19 @@ func SaveReceivedFile(receivedFile *ReceivedFile) (*ReceivedFile, error) {
return receivedFile, db.Error
}
func BatchSaveReceivedFiles(receivedFiles []*ReceivedFile) error {
if len(receivedFiles) == 0 {
return nil
}
for _, file := range receivedFiles {
record, err := GetReceivedFileByChatAndMessageID(file.ChatID, file.MessageID)
if err == nil {
file.ID = record.ID
}
}
return db.Save(receivedFiles).Error
}
func GetReceivedFileByChatAndMessageID(chatID int64, messageID int) (*ReceivedFile, error) {
var receivedFile ReceivedFile
err := db.Where("chat_id = ? AND message_id = ?", chatID, messageID).First(&receivedFile).Error