From 3e3a3206721bca19d97314689fed9556d3574977 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 22 Mar 2025 11:52:43 +0800 Subject: [PATCH] feat: download telegraph images , close #5 --- bot/bot.go | 1 + bot/handle_add_task.go | 69 +++++++++++++++--------- bot/handle_file.go | 2 +- bot/handle_link.go | 2 +- bot/handle_save.go | 2 +- bot/handle_telegraph.go | 114 ++++++++++++++++++++++++++++++++++++++++ bot/handlers.go | 5 ++ bot/utils.go | 15 ++++-- common/cache.go | 6 ++- core/download.go | 111 +++++++++++++++++++++++++++++++++++++- core/utils.go | 4 +- dao/model.go | 2 + go.mod | 1 + go.sum | 2 + types/types.go | 25 +++++++-- types/utils.go | 12 +++++ 16 files changed, 333 insertions(+), 40 deletions(-) create mode 100644 bot/handle_telegraph.go create mode 100644 types/utils.go diff --git a/bot/bot.go b/bot/bot.go index 9e2558e..45cc886 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -27,6 +27,7 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) { } func Init() { + InitTelegraphClient() common.Log.Info("初始化 Telegram 客户端...") ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() diff --git a/bot/handle_add_task.go b/bot/handle_add_task.go index 5f77d90..8b9cbd6 100644 --- a/bot/handle_add_task.go +++ b/bot/handle_add_task.go @@ -33,7 +33,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } args := strings.Split(string(update.CallbackQuery.Data), " ") - addToDir := args[0] == "add_to_dir" + addToDir := args[0] == "add_to_dir" // 已经选择了路径 cbDataId, _ := strconv.Atoi(args[1]) cbData, err := dao.GetCallbackData(uint(cbDataId)) if err != nil { @@ -136,31 +136,50 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error { } } - file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName) - if err != nil { - common.Log.Errorf("获取消息中的文件失败: %s", err) - ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ - QueryID: update.CallbackQuery.QueryID, - Alert: true, - Message: fmt.Sprintf("获取消息中的文件失败: %s", err), - CacheTime: 5, - }) - return dispatcher.EndGroups - } + var task types.Task + if record.IsTelegraph { + task = types.Task{ + Ctx: ctx, + Status: types.Pending, + IsTelegraph: true, + TelegraphURL: record.TelegraphURL, + StorageName: storageName, + FileChatID: record.ChatID, + FileMessageID: record.MessageID, + ReplyMessageID: record.ReplyMessageID, + ReplyChatID: record.ReplyChatID, + UserID: update.GetUserChat().GetID(), + } + if dir != nil { + task.StoragePath = path.Join(dir.Path, record.FileName) + } + } else { + file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName) + if err != nil { + common.Log.Errorf("获取消息中的文件失败: %s", err) + ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{ + QueryID: update.CallbackQuery.QueryID, + Alert: true, + Message: fmt.Sprintf("获取消息中的文件失败: %s", err), + CacheTime: 5, + }) + return dispatcher.EndGroups + } - task := types.Task{ - Ctx: ctx, - Status: types.Pending, - File: file, - StorageName: storageName, - FileChatID: record.ChatID, - ReplyMessageID: record.ReplyMessageID, - FileMessageID: record.MessageID, - ReplyChatID: record.ReplyChatID, - UserID: update.GetUserChat().GetID(), - } - if dir != nil { - task.StoragePath = path.Join(dir.Path, file.FileName) + task = types.Task{ + Ctx: ctx, + Status: types.Pending, + File: file, + StorageName: storageName, + FileChatID: record.ChatID, + ReplyMessageID: record.ReplyMessageID, + FileMessageID: record.MessageID, + ReplyChatID: record.ReplyChatID, + UserID: update.GetUserChat().GetID(), + } + if dir != nil { + task.StoragePath = path.Join(dir.Path, file.FileName) + } } queue.AddTask(&task) diff --git a/bot/handle_file.go b/bot/handle_file.go index b140cba..96dec60 100644 --- a/bot/handle_file.go +++ b/bot/handle_file.go @@ -69,7 +69,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error { } if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID) + return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, diff --git a/bot/handle_link.go b/bot/handle_link.go index ef75a0e..a62cf16 100644 --- a/bot/handle_link.go +++ b/bot/handle_link.go @@ -92,7 +92,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID) + return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, diff --git a/bot/handle_save.go b/bot/handle_save.go index 3c477a0..43dbcff 100644 --- a/bot/handle_save.go +++ b/bot/handle_save.go @@ -99,7 +99,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error { return dispatcher.EndGroups } if !user.Silent || user.DefaultStorage == "" { - return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID) + return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), msg.ID, replied.ID) } return HandleSilentAddTask(ctx, update, user, &types.Task{ Ctx: ctx, diff --git a/bot/handle_telegraph.go b/bot/handle_telegraph.go new file mode 100644 index 0000000..40d61ae --- /dev/null +++ b/bot/handle_telegraph.go @@ -0,0 +1,114 @@ +package bot + +import ( + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/celestix/telegraph-go/v2" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/dao" + "github.com/krau/SaveAny-Bot/storage" + "github.com/krau/SaveAny-Bot/types" +) + +var ( + TelegraphClient *telegraph.TelegraphClient + TelegraphUrlRegexString = `https://telegra.ph/.*` + TelegraphUrlRegex = regexp.MustCompile(TelegraphUrlRegexString) +) + +func InitTelegraphClient() { + var httpClient *http.Client + if config.Cfg.Telegram.Proxy.Enable { + proxyUrl, err := url.Parse(config.Cfg.Telegram.Proxy.URL) + if err != nil { + fmt.Println("Error parsing proxy URL:", err) + return + } + proxy := http.ProxyURL(proxyUrl) + httpClient = &http.Client{ + Transport: &http.Transport{ + Proxy: proxy, + }, + Timeout: 30 * time.Second, + } + } else { + httpClient = &http.Client{ + Timeout: 30 * time.Second, + } + } + TelegraphClient = telegraph.GetTelegraphClient(&telegraph.ClientOpt{HttpClient: httpClient}) +} + +func handleTelegraph(ctx *ext.Context, update *ext.Update) error { + common.Log.Trace("Got telegraph link") + tgphUrl := TelegraphUrlRegex.FindString(update.EffectiveMessage.Text) + if tgphUrl == "" { + return dispatcher.ContinueGroups + } + replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil) + if err != nil { + common.Log.Errorf("回复失败: %s", err) + return dispatcher.EndGroups + } + user, err := dao.GetUserByChatID(update.GetUserChat().GetID()) + if err != nil { + common.Log.Errorf("获取用户失败: %s", err) + ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil) + return dispatcher.EndGroups + } + storages := storage.GetUserStorages(user.ChatID) + + if len(storages) == 0 { + ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil) + return dispatcher.EndGroups + } + + tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1] + fileName, err := url.PathUnescape(tgphPath) + if err != nil { + common.Log.Errorf("解析 Telegraph 路径失败: %s", err) + fileName = tgphPath + } + + record := &dao.ReceivedFile{ + Processing: false, + FileName: fileName, + ChatID: update.EffectiveChat().GetID(), + MessageID: update.EffectiveMessage.GetID(), + ReplyMessageID: replied.ID, + ReplyChatID: update.EffectiveChat().GetID(), + IsTelegraph: true, + TelegraphURL: tgphUrl, + } + if err := dao.SaveReceivedFile(record); err != nil { + common.Log.Errorf("保存接收的文件失败: %s", err) + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + Message: "无法保存文件: " + err.Error(), + ID: replied.ID, + }) + return dispatcher.EndGroups + } + + if !user.Silent || user.DefaultStorage == "" { + return ProvideSelectMessage(ctx, update, fileName, update.EffectiveChat().GetID(), update.EffectiveMessage.GetID(), replied.ID) + } + return HandleSilentAddTask(ctx, update, user, &types.Task{ + Ctx: ctx, + Status: types.Pending, + StorageName: user.DefaultStorage, + UserID: user.ChatID, + ReplyMessageID: replied.ID, + ReplyChatID: update.GetUserChat().GetID(), + IsTelegraph: true, + TelegraphURL: tgphUrl, + }) +} diff --git a/bot/handlers.go b/bot/handlers.go index 84beaf8..410d363 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -20,6 +20,11 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) { common.Log.Panicf("创建正则表达式过滤器失败: %s", err) } dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage)) + telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString) + if err != nil { + common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err) + } + dispatcher.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage)) dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask)) diff --git a/bot/utils.go b/bot/utils.go index 3902f9f..9fbab98 100644 --- a/bot/utils.go +++ b/bot/utils.go @@ -200,7 +200,13 @@ func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileNa } func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) { + key := fmt.Sprintf("message:%d:%d", chatId, messageID) common.Log.Debugf("Fetching message: %d", messageID) + var cachedMessage tg.Message + err := common.Cache.Get(key, &cachedMessage) + if err == nil { + return &cachedMessage, nil + } messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}}) if err != nil { return nil, err @@ -213,16 +219,19 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e if !ok { return nil, fmt.Errorf("unexpected message type: %T", msg) } + if err := common.Cache.Set(key, tgMessage, 3600); err != nil { + common.Log.Errorf("Failed to cache message: %s", err) + } return tgMessage, nil } -func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error { +func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, fileName string, chatID int64, fileMsgID, toEditMsgID int) error { entityBuilder := entity.Builder{} var entities []tg.MessageEntityClass - text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName) + text := fmt.Sprintf("文件名: %s\n请选择存储位置", fileName) if err := styling.Perform(&entityBuilder, styling.Plain("文件名: "), - styling.Code(file.FileName), + styling.Code(fileName), styling.Plain("\n请选择存储位置"), ); err != nil { common.Log.Errorf("Failed to build entity: %s", err) diff --git a/common/cache.go b/common/cache.go index 657bb66..9f533d9 100644 --- a/common/cache.go +++ b/common/cache.go @@ -21,10 +21,12 @@ func initCache() { gob.Register(types.File{}) gob.Register(tg.InputDocumentFileLocation{}) gob.Register(tg.InputPhotoFileLocation{}) + gob.Register(tg.Message{}) + gob.Register(tg.PeerUser{}) Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)} } -func (c *CommonCache) Get(key string, value *types.File) error { +func (c *CommonCache) Get(key string, value any) error { c.mu.RLock() defer c.mu.RUnlock() data, err := Cache.cache.Get([]byte(key)) @@ -39,7 +41,7 @@ func (c *CommonCache) Get(key string, value *types.File) error { return nil } -func (c *CommonCache) Set(key string, value *types.File, expireSeconds int) error { +func (c *CommonCache) Set(key string, value any, expireSeconds int) error { c.mu.Lock() defer c.mu.Unlock() var buf bytes.Buffer diff --git a/core/download.go b/core/download.go index 36b6e08..c2f1474 100644 --- a/core/download.go +++ b/core/download.go @@ -2,12 +2,17 @@ package core import ( "context" + "encoding/json" "fmt" "io" + "net/http" + "path" "path/filepath" + "strings" "time" "github.com/celestix/gotgproto/ext" + "github.com/celestix/telegraph-go/v2" "github.com/duke-git/lancet/v2/fileutil" "github.com/gotd/td/tg" "github.com/krau/SaveAny-Bot/bot" @@ -25,7 +30,7 @@ func processPendingTask(task *types.Task) error { } if task.StoragePath == "" { - task.StoragePath = task.File.FileName + task.StoragePath = task.FileName() } taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName) @@ -42,6 +47,10 @@ func processPendingTask(task *types.Task) error { cancelCtx, cancel := context.WithCancel(ctx) task.Cancel = cancel + if task.IsTelegraph { + return processTelegraph(ctx, cancelCtx, task, taskStorage) + } + if task.File.FileSize == 0 { return processPhoto(task, taskStorage) } @@ -123,5 +132,103 @@ func processPendingTask(task *types.Task) error { ID: task.ReplyMessageID, }) - return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath) + return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath) +} + +func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error { + if bot.TelegraphClient == nil { + return fmt.Errorf("telegraph client is not initialized") + } + tgphUrl := task.TelegraphURL + tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1] + if tgphUrl == "" || tgphPath == "" { + return fmt.Errorf("invalid telegraph url") + } + + resultCh := make(chan error) + go func() { + page, err := bot.TelegraphClient.GetPage(tgphPath, true) + if err != nil { + resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err) + return + } + imgs := make([]string, 0) + for _, element := range page.Content { + var node telegraph.NodeElement + data, err := json.Marshal(element) + if err != nil { + common.Log.Errorf("Failed to marshal element: %s", err) + continue + } + err = json.Unmarshal(data, &node) + if err != nil { + common.Log.Errorf("Failed to unmarshal element: %s", err) + continue + } + if node.Tag == "img" { + if src, ok := node.Attrs["src"]; ok { + imgs = append(imgs, src) + } + } + + } + if len(imgs) == 0 { + resultCh <- fmt.Errorf("没有找到图片") + return + } + hc := bot.TelegraphClient.HttpClient + eg, ectx := errgroup.WithContext(cancelCtx) + eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this + for i, img := range imgs { + eg.Go(func() error { + var lastErr error + for attempt := range config.Cfg.Retry { + if attempt > 0 { + retryDelay := time.Duration(attempt*attempt) * time.Second + select { + case <-ectx.Done(): + return ectx.Err() + case <-time.After(retryDelay): + } + common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1) + } + req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + resp, err := hc.Do(req) + if err != nil { + lastErr = fmt.Errorf("发送请求失败: %w", err) + continue + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("请求图片失败: %s", resp.Status) + continue + } + targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img))) + err = taskStorage.Save(ectx, resp.Body, targetPath) + if err != nil { + lastErr = fmt.Errorf("保存图片失败: %w", err) + continue + } + common.Log.Infof("Saved image: %s", targetPath) + return nil + } + return lastErr + }) + } + if err := eg.Wait(); err != nil { + resultCh <- err + return + } + resultCh <- nil + }() + select { + case err := <-resultCh: + return err + case <-cancelCtx.Done(): + return cancelCtx.Err() + } } diff --git a/core/utils.go b/core/utils.go index ad882d7..4d724e9 100644 --- a/core/utils.go +++ b/core/utils.go @@ -21,7 +21,7 @@ import ( "github.com/krau/SaveAny-Bot/types" ) -func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, cacheFilePath string) error { +func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error { for i := 0; i <= config.Cfg.Retry; i++ { if err := ctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) @@ -30,7 +30,7 @@ func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storag if err != nil { return fmt.Errorf("failed to open cache file: %w", err) } - if err := taskStorage.Save(ctx, file, task.StoragePath); err != nil { + if err := taskStorage.Save(ctx, file, storagePath); err != nil { if i == config.Cfg.Retry { return fmt.Errorf("failed to save file: %w", err) } diff --git a/dao/model.go b/dao/model.go index 4d95cd4..d1668c5 100644 --- a/dao/model.go +++ b/dao/model.go @@ -14,6 +14,8 @@ type ReceivedFile struct { ReplyMessageID int ReplyChatID int64 FileName string + IsTelegraph bool + TelegraphURL string } type User struct { diff --git a/go.mod b/go.mod index c1affaa..4996026 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.5 require ( github.com/blang/semver v3.5.1+incompatible github.com/celestix/gotgproto v1.0.0-beta20.2 + github.com/celestix/telegraph-go/v2 v2.0.4 github.com/gabriel-vasile/mimetype v1.4.8 github.com/gookit/slog v0.5.7 github.com/gotd/contrib v0.21.0 diff --git a/go.sum b/go.sum index bce7629..00d749e 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdn github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/celestix/gotgproto v1.0.0-beta20.2 h1:+WcsKdsyj4xy+TAV+4Sw6zp1xiQrIr4dMnM31+k8NYM= github.com/celestix/gotgproto v1.0.0-beta20.2/go.mod h1:j42ZhBMUke6QyBLvCgx8tA+TL9L3+pq/Q46B+b5+3aU= +github.com/celestix/telegraph-go/v2 v2.0.4 h1:w8HWymJFhMSMPjdGoyTh3/NqE3eXAT1njTvelh0338k= +github.com/celestix/telegraph-go/v2 v2.0.4/go.mod h1:vu2LtqM7MgOAJ2LDF8XK27DWdd1QYLBfZGhalEh086Y= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= diff --git a/types/types.go b/types/types.go index 1c32eaa..5f06acc 100644 --- a/types/types.go +++ b/types/types.go @@ -5,6 +5,8 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "net/url" + "strings" "time" "github.com/gotd/td/tg" @@ -41,29 +43,46 @@ type Task struct { Cancel context.CancelFunc Error error Status TaskStatus - File *File StorageName string StoragePath string StartTime time.Time + File *File FileMessageID int FileChatID int64 + + IsTelegraph bool + TelegraphURL string + // to track the reply message ReplyMessageID int ReplyChatID int64 - // to track the user - UserID int64 + UserID int64 } func (t Task) Key() string { + if t.IsTelegraph { + return hashStr(t.TelegraphURL) + } return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID) } func (t Task) String() string { + if t.IsTelegraph { + return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL) + } return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName) } func (t Task) FileName() string { + if t.IsTelegraph { + tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1] + tgphPathUnescaped, err := url.PathUnescape(tgphPath) + if err != nil { + return tgphPath + } + return tgphPathUnescaped + } return t.File.FileName } diff --git a/types/utils.go b/types/utils.go new file mode 100644 index 0000000..85ab6f4 --- /dev/null +++ b/types/utils.go @@ -0,0 +1,12 @@ +package types + +import ( + "crypto/md5" + "encoding/hex" +) + +func hashStr(s string) string { + hash := md5.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +}