feat: add user client

This commit is contained in:
krau
2025-06-08 15:36:14 +08:00
parent 481427683e
commit c7c458f147
15 changed files with 501 additions and 45 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/queue"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/userclient"
"gorm.io/gorm"
)
@@ -153,7 +154,14 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
task.StoragePath = path.Join(dir.Path, record.FileName)
}
} else {
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
var file *types.File
var err error
if record.UseUserClient && userclient.UC != nil {
uctx := userclient.UC.CreateContext()
file, err = FileFromMessage(uctx, record.ChatID, record.MessageID, record.FileName)
} else {
file, err = FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
}
if err != nil {
common.Log.Errorf("获取消息中的文件失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
@@ -168,6 +176,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
task = types.Task{
Ctx: ctx,
Status: types.Pending,
UseUserClient: record.UseUserClient,
FileDBID: record.ID,
File: file,
StorageName: storageName,

View File

@@ -12,6 +12,7 @@ import (
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/types"
"github.com/krau/SaveAny-Bot/userclient"
)
var (
@@ -50,6 +51,31 @@ func parseLink(ctx *ext.Context, link string) (chatID int64, messageID int, err
return chatID, messageID, nil
}
// use passed ctx client to fetch file from message,
//
// if failed try using userclient
func tryFetchFileFromMessage(ctx *ext.Context, chatID int64, messageID int, fileName string) (*types.File, bool, error) {
file, err := FileFromMessage(ctx, chatID, messageID, fileName)
if err == nil {
return file, false, nil
}
if (strings.Contains(err.Error(), "peer not found") || strings.Contains(err.Error(), "unexpected message type")) && userclient.UC != nil {
common.Log.Warnf("无法获取文件 %d:%d, 尝试使用 userbot: %s", chatID, messageID, err)
uctx := userclient.GetCtx()
// TODO: 群组支持
file, err = FileFromMessage(uctx, chatID, messageID, fileName)
if err == nil {
return file, true, nil
}
return nil, true, err
}
return nil, false, err
}
func tryFetchMessage(ctx *ext.Context, chatID int64, messageID int) (*tg.Message, error) {
return GetTGMessage(ctx, chatID, messageID)
}
func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
common.Log.Trace("Got link message")
link := linkRegex.FindString(update.EffectiveMessage.Text)
@@ -70,26 +96,25 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
// storages := storage.GetUserStorages(user.ChatID)
// if len(storages) == 0 {
// ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
// return dispatcher.EndGroups
// }
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
file, err := FileFromMessage(ctx, linkChatID, messageID, "")
file, useUserClient, err := tryFetchFileFromMessage(ctx, linkChatID, messageID, "")
if err != nil {
common.Log.Errorf("获取文件失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取文件失败: "+err.Error()), nil)
return dispatcher.EndGroups
}
if file.FileName == "" {
file.FileName = GenFileNameFromMessage(*update.EffectiveMessage.Message, file)
msg, err := tryFetchMessage(ctx, linkChatID, messageID)
if err != nil {
file.FileName = fmt.Sprintf("%d_%d", linkChatID, messageID)
} else {
file.FileName = GenFileNameFromMessage(*msg, file)
}
}
receivedFile := &dao.ReceivedFile{
@@ -99,6 +124,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
MessageID: messageID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
UseUserClient: useUserClient,
}
record, err := dao.SaveReceivedFile(receivedFile)
if err != nil {
@@ -116,6 +142,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
Ctx: ctx,
Status: types.Pending,
FileDBID: record.ID,
UseUserClient: useUserClient,
File: file,
StorageName: user.DefaultStorage,
UserID: user.ChatID,

View File

@@ -4,31 +4,38 @@ import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/dispatcher/handlers"
"github.com/celestix/gotgproto/dispatcher/handlers/filters"
"github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common"
)
func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
dispatcher.AddHandler(handlers.NewMessage(filters.Message.All, checkPermission))
dispatcher.AddHandler(handlers.NewCommand("start", start))
dispatcher.AddHandler(handlers.NewCommand("help", help))
dispatcher.AddHandler(handlers.NewCommand("silent", silent))
dispatcher.AddHandler(handlers.NewCommand("storage", storageCmd))
dispatcher.AddHandler(handlers.NewCommand("save", saveCmd))
dispatcher.AddHandler(handlers.NewCommand("dir", dirCmd))
dispatcher.AddHandler(handlers.NewCommand("rule", ruleCmd))
func RegisterHandlers(disp dispatcher.Dispatcher) {
disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChannel), func(ctx *ext.Context, u *ext.Update) error {
return dispatcher.EndGroups
}))
disp.AddHandler(handlers.NewMessage(filters.Message.ChatType(filters.ChatTypeChat), func(ctx *ext.Context, u *ext.Update) error {
return dispatcher.EndGroups
}))
disp.AddHandler(handlers.NewMessage(filters.Message.All, checkPermission))
disp.AddHandler(handlers.NewCommand("start", start))
disp.AddHandler(handlers.NewCommand("help", help))
disp.AddHandler(handlers.NewCommand("silent", silent))
disp.AddHandler(handlers.NewCommand("storage", storageCmd))
disp.AddHandler(handlers.NewCommand("save", saveCmd))
disp.AddHandler(handlers.NewCommand("dir", dirCmd))
disp.AddHandler(handlers.NewCommand("rule", ruleCmd))
linkRegexFilter, err := filters.Message.Regex(linkRegexString)
if err != nil {
common.Log.Panicf("创建正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
disp.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString)
if err != nil {
common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("send_here"), sendFileToTelegram))
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
disp.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
disp.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("send_here"), sendFileToTelegram))
disp.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
}