diff --git a/bot/handle_link.go b/bot/handle_link.go index f9fd55f..3fb36e2 100644 --- a/bot/handle_link.go +++ b/bot/handle_link.go @@ -1,6 +1,7 @@ package bot import ( + "fmt" "regexp" "strconv" "strings" @@ -18,47 +19,47 @@ var ( linkRegex = regexp.MustCompile(linkRegexString) ) +func parseLink(ctx *ext.Context, link string) (chatID int64, messageID int, err error) { + strSlice := strings.Split(link, "/") + if len(strSlice) < 3 { + return 0, 0, fmt.Errorf("链接格式错误: %s", link) + } + messageID, err = strconv.Atoi(strSlice[len(strSlice)-1]) + if err != nil { + return 0, 0, fmt.Errorf("无法解析消息 ID: %s", err) + } + if len(strSlice) == 3 { + chatUsername := strSlice[1] + linkChat, err := ctx.ResolveUsername(chatUsername) + if err != nil { + return 0, 0, fmt.Errorf("解析用户名失败: %s", err) + } + if linkChat == nil { + return 0, 0, fmt.Errorf("找不到该聊天: %s", chatUsername) + } + chatID = linkChat.GetID() + } else if len(strSlice) == 4 { + chatIDInt, err := strconv.Atoi(strSlice[2]) + if err != nil { + return 0, 0, fmt.Errorf("无法解析 Chat ID: %s", err) + } + chatID = int64(chatIDInt) + } else { + return 0, 0, fmt.Errorf("无效的链接: %s", link) + } + return chatID, messageID, nil +} + func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { common.Log.Trace("Got link message") link := linkRegex.FindString(update.EffectiveMessage.Text) if link == "" { return dispatcher.ContinueGroups } - strSlice := strings.Split(link, "/") - if len(strSlice) < 3 { - return dispatcher.ContinueGroups - } - messageID, err := strconv.Atoi(strSlice[len(strSlice)-1]) + linkChatID, messageID, err := parseLink(ctx, link) if err != nil { - common.Log.Errorf("解析消息 ID 失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("无法解析消息 ID"), nil) - return dispatcher.EndGroups - } - var linkChatID int64 - if len(strSlice) == 3 { - chatUsername := strSlice[1] - linkChat, err := ctx.ResolveUsername(chatUsername) - if err != nil { - common.Log.Errorf("解析用户名失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("解析用户名失败"), nil) - return dispatcher.EndGroups - } - if linkChat == nil { - common.Log.Errorf("无法找到聊天: %s", chatUsername) - ctx.Reply(update, ext.ReplyTextString("无法找到聊天"), nil) - return dispatcher.EndGroups - } - linkChatID = linkChat.GetID() - } else if len(strSlice) == 4 { - chatID, err := strconv.Atoi(strSlice[2]) - if err != nil { - common.Log.Errorf("解析 Chat ID 失败: %s", err) - ctx.Reply(update, ext.ReplyTextString("解析 Chat ID 失败"), nil) - return dispatcher.EndGroups - } - linkChatID = int64(chatID) - } else { - ctx.Reply(update, ext.ReplyTextString("无法解析链接"), nil) + common.Log.Errorf("解析链接失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("解析链接失败: "+err.Error()), nil) return dispatcher.EndGroups } diff --git a/bot/middlewares.go b/bot/middlewares.go index 715272a..b994141 100644 --- a/bot/middlewares.go +++ b/bot/middlewares.go @@ -9,6 +9,7 @@ import ( "github.com/gotd/contrib/middleware/floodwait" "github.com/gotd/contrib/middleware/ratelimit" "github.com/gotd/td/telegram" + "github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/config" "golang.org/x/time/rate" ) @@ -30,8 +31,38 @@ const noPermissionText string = ` func checkPermission(ctx *ext.Context, update *ext.Update) error { userID := update.GetUserChat().GetID() if !slice.Contain(config.Cfg.GetUsersID(), userID) { + if config.Cfg.AsPublicCopyMediaBot { + tryCopyMedia(ctx, update) + return dispatcher.EndGroups + } ctx.Reply(update, ext.ReplyTextString(noPermissionText), nil) return dispatcher.EndGroups } return dispatcher.ContinueGroups } + +func tryCopyMedia(ctx *ext.Context, update *ext.Update) { + if !config.Cfg.AsPublicCopyMediaBot { + return + } + if update.EffectiveMessage == nil || update.EffectiveMessage.Message == nil || update.EffectiveMessage.Media == nil { + return + } + common.Log.Tracef("Got media from %d: %s", update.EffectiveChat().GetID(), update.EffectiveMessage.Media.TypeName()) + msg := update.EffectiveMessage.Message + if link := linkRegex.FindString(update.EffectiveMessage.Text); link != "" { + linkChatID, messageID, err := parseLink(ctx, link) + if err != nil { + return + } + fileMessage, err := GetTGMessage(ctx, linkChatID, messageID) + if err != nil { + return + } + if fileMessage == nil || fileMessage.Media == nil { + return + } + msg = fileMessage + } + copyMediaToChat(ctx, msg, update.EffectiveChat().GetID()) +} diff --git a/bot/utils.go b/bot/utils.go index bf9a1b4..88b0741 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -216,7 +216,7 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e if err == nil { return cacheMessage, nil } - common.Log.Debugf("Fetching message: %d", messageID) + common.Log.Debugf("Fetching message: %d:%d", chatId, messageID) messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}}) if err != nil { return nil, err diff --git a/config/viper.go b/config/viper.go index 112f878..29feacd 100644 --- a/config/viper.go +++ b/config/viper.go @@ -17,6 +17,9 @@ type Config struct { Threads int `toml:"threads" mapstructure:"threads" json:"threads"` Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` + // Experimental: 将拷贝媒体文件的功能设为公开可用 + AsPublicCopyMediaBot bool `toml:"as_public_copy_media_bot" mapstructure:"as_public_copy_media_bot" json:"as_public_copy_media_bot"` + Users []userConfig `toml:"users" mapstructure:"users" json:"users"` Temp tempConfig `toml:"temp" mapstructure:"temp"`