feat: download telegraph images , close #5

This commit is contained in:
krau
2025-03-22 11:52:43 +08:00
parent 19efab0665
commit 3e3a320672
16 changed files with 333 additions and 40 deletions

View File

@@ -27,6 +27,7 @@ func newProxyDialer(proxyUrl string) (proxy.Dialer, error) {
}
func Init() {
InitTelegraphClient()
common.Log.Info("初始化 Telegram 客户端...")
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()

View File

@@ -33,7 +33,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
args := strings.Split(string(update.CallbackQuery.Data), " ")
addToDir := args[0] == "add_to_dir"
addToDir := args[0] == "add_to_dir" // 已经选择了路径
cbDataId, _ := strconv.Atoi(args[1])
cbData, err := dao.GetCallbackData(uint(cbDataId))
if err != nil {
@@ -136,31 +136,50 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
}
}
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
common.Log.Errorf("获取消息中的文件失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
CacheTime: 5,
})
return dispatcher.EndGroups
}
var task types.Task
if record.IsTelegraph {
task = types.Task{
Ctx: ctx,
Status: types.Pending,
IsTelegraph: true,
TelegraphURL: record.TelegraphURL,
StorageName: storageName,
FileChatID: record.ChatID,
FileMessageID: record.MessageID,
ReplyMessageID: record.ReplyMessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.GetUserChat().GetID(),
}
if dir != nil {
task.StoragePath = path.Join(dir.Path, record.FileName)
}
} else {
file, err := FileFromMessage(ctx, record.ChatID, record.MessageID, record.FileName)
if err != nil {
common.Log.Errorf("获取消息中的文件失败: %s", err)
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
QueryID: update.CallbackQuery.QueryID,
Alert: true,
Message: fmt.Sprintf("获取消息中的文件失败: %s", err),
CacheTime: 5,
})
return dispatcher.EndGroups
}
task := types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: storageName,
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.GetUserChat().GetID(),
}
if dir != nil {
task.StoragePath = path.Join(dir.Path, file.FileName)
task = types.Task{
Ctx: ctx,
Status: types.Pending,
File: file,
StorageName: storageName,
FileChatID: record.ChatID,
ReplyMessageID: record.ReplyMessageID,
FileMessageID: record.MessageID,
ReplyChatID: record.ReplyChatID,
UserID: update.GetUserChat().GetID(),
}
if dir != nil {
task.StoragePath = path.Join(dir.Path, file.FileName)
}
}
queue.AddTask(&task)

View File

@@ -69,7 +69,7 @@ func handleFileMessage(ctx *ext.Context, update *ext.Update) error {
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), update.EffectiveMessage.ID, msg.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,

View File

@@ -92,7 +92,7 @@ func handleLinkMessage(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, linkChat.GetID(), messageID, replied.ID)
return ProvideSelectMessage(ctx, update, file.FileName, linkChat.GetID(), messageID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,

View File

@@ -99,7 +99,7 @@ func saveCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, file, update.EffectiveChat().GetID(), msg.ID, replied.ID)
return ProvideSelectMessage(ctx, update, file.FileName, update.EffectiveChat().GetID(), msg.ID, replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,

114
bot/handle_telegraph.go Normal file
View File

@@ -0,0 +1,114 @@
package bot
import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/dao"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
var (
TelegraphClient *telegraph.TelegraphClient
TelegraphUrlRegexString = `https://telegra.ph/.*`
TelegraphUrlRegex = regexp.MustCompile(TelegraphUrlRegexString)
)
func InitTelegraphClient() {
var httpClient *http.Client
if config.Cfg.Telegram.Proxy.Enable {
proxyUrl, err := url.Parse(config.Cfg.Telegram.Proxy.URL)
if err != nil {
fmt.Println("Error parsing proxy URL:", err)
return
}
proxy := http.ProxyURL(proxyUrl)
httpClient = &http.Client{
Transport: &http.Transport{
Proxy: proxy,
},
Timeout: 30 * time.Second,
}
} else {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
TelegraphClient = telegraph.GetTelegraphClient(&telegraph.ClientOpt{HttpClient: httpClient})
}
func handleTelegraph(ctx *ext.Context, update *ext.Update) error {
common.Log.Trace("Got telegraph link")
tgphUrl := TelegraphUrlRegex.FindString(update.EffectiveMessage.Text)
if tgphUrl == "" {
return dispatcher.ContinueGroups
}
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件..."), nil)
if err != nil {
common.Log.Errorf("回复失败: %s", err)
return dispatcher.EndGroups
}
user, err := dao.GetUserByChatID(update.GetUserChat().GetID())
if err != nil {
common.Log.Errorf("获取用户失败: %s", err)
ctx.Reply(update, ext.ReplyTextString("获取用户失败"), nil)
return dispatcher.EndGroups
}
storages := storage.GetUserStorages(user.ChatID)
if len(storages) == 0 {
ctx.Reply(update, ext.ReplyTextString("无可用的存储"), nil)
return dispatcher.EndGroups
}
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
fileName, err := url.PathUnescape(tgphPath)
if err != nil {
common.Log.Errorf("解析 Telegraph 路径失败: %s", err)
fileName = tgphPath
}
record := &dao.ReceivedFile{
Processing: false,
FileName: fileName,
ChatID: update.EffectiveChat().GetID(),
MessageID: update.EffectiveMessage.GetID(),
ReplyMessageID: replied.ID,
ReplyChatID: update.EffectiveChat().GetID(),
IsTelegraph: true,
TelegraphURL: tgphUrl,
}
if err := dao.SaveReceivedFile(record); err != nil {
common.Log.Errorf("保存接收的文件失败: %s", err)
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
Message: "无法保存文件: " + err.Error(),
ID: replied.ID,
})
return dispatcher.EndGroups
}
if !user.Silent || user.DefaultStorage == "" {
return ProvideSelectMessage(ctx, update, fileName, update.EffectiveChat().GetID(), update.EffectiveMessage.GetID(), replied.ID)
}
return HandleSilentAddTask(ctx, update, user, &types.Task{
Ctx: ctx,
Status: types.Pending,
StorageName: user.DefaultStorage,
UserID: user.ChatID,
ReplyMessageID: replied.ID,
ReplyChatID: update.GetUserChat().GetID(),
IsTelegraph: true,
TelegraphURL: tgphUrl,
})
}

View File

@@ -20,6 +20,11 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
common.Log.Panicf("创建正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
telegraphUrlRegexFilter, err := filters.Message.Regex(TelegraphUrlRegexString)
if err != nil {
common.Log.Panicf("创建 Telegraph URL 正则表达式过滤器失败: %s", err)
}
dispatcher.AddHandler(handlers.NewMessage(telegraphUrlRegexFilter, handleTelegraph))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))

View File

@@ -200,7 +200,13 @@ func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileNa
}
func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) {
key := fmt.Sprintf("message:%d:%d", chatId, messageID)
common.Log.Debugf("Fetching message: %d", messageID)
var cachedMessage tg.Message
err := common.Cache.Get(key, &cachedMessage)
if err == nil {
return &cachedMessage, nil
}
messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}})
if err != nil {
return nil, err
@@ -213,16 +219,19 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
if !ok {
return nil, fmt.Errorf("unexpected message type: %T", msg)
}
if err := common.Cache.Set(key, tgMessage, 3600); err != nil {
common.Log.Errorf("Failed to cache message: %s", err)
}
return tgMessage, nil
}
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, file *types.File, chatID int64, fileMsgID, toEditMsgID int) error {
func ProvideSelectMessage(ctx *ext.Context, update *ext.Update, fileName string, chatID int64, fileMsgID, toEditMsgID int) error {
entityBuilder := entity.Builder{}
var entities []tg.MessageEntityClass
text := fmt.Sprintf("文件名: %s\n请选择存储位置", file.FileName)
text := fmt.Sprintf("文件名: %s\n请选择存储位置", fileName)
if err := styling.Perform(&entityBuilder,
styling.Plain("文件名: "),
styling.Code(file.FileName),
styling.Code(fileName),
styling.Plain("\n请选择存储位置"),
); err != nil {
common.Log.Errorf("Failed to build entity: %s", err)

View File

@@ -21,10 +21,12 @@ func initCache() {
gob.Register(types.File{})
gob.Register(tg.InputDocumentFileLocation{})
gob.Register(tg.InputPhotoFileLocation{})
gob.Register(tg.Message{})
gob.Register(tg.PeerUser{})
Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)}
}
func (c *CommonCache) Get(key string, value *types.File) error {
func (c *CommonCache) Get(key string, value any) error {
c.mu.RLock()
defer c.mu.RUnlock()
data, err := Cache.cache.Get([]byte(key))
@@ -39,7 +41,7 @@ func (c *CommonCache) Get(key string, value *types.File) error {
return nil
}
func (c *CommonCache) Set(key string, value *types.File, expireSeconds int) error {
func (c *CommonCache) Set(key string, value any, expireSeconds int) error {
c.mu.Lock()
defer c.mu.Unlock()
var buf bytes.Buffer

View File

@@ -2,12 +2,17 @@ package core
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"
"path/filepath"
"strings"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
@@ -25,7 +30,7 @@ func processPendingTask(task *types.Task) error {
}
if task.StoragePath == "" {
task.StoragePath = task.File.FileName
task.StoragePath = task.FileName()
}
taskStorage, err := storage.GetStorageByUserIDAndName(task.UserID, task.StorageName)
@@ -42,6 +47,10 @@ func processPendingTask(task *types.Task) error {
cancelCtx, cancel := context.WithCancel(ctx)
task.Cancel = cancel
if task.IsTelegraph {
return processTelegraph(ctx, cancelCtx, task, taskStorage)
}
if task.File.FileSize == 0 {
return processPhoto(task, taskStorage)
}
@@ -123,5 +132,103 @@ func processPendingTask(task *types.Task) error {
ID: task.ReplyMessageID,
})
return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath)
return saveFileWithRetry(cancelCtx, task.StoragePath, taskStorage, cacheDestPath)
}
func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *types.Task, taskStorage storage.Storage) error {
if bot.TelegraphClient == nil {
return fmt.Errorf("telegraph client is not initialized")
}
tgphUrl := task.TelegraphURL
tgphPath := strings.Split(tgphUrl, "/")[len(strings.Split(tgphUrl, "/"))-1]
if tgphUrl == "" || tgphPath == "" {
return fmt.Errorf("invalid telegraph url")
}
resultCh := make(chan error)
go func() {
page, err := bot.TelegraphClient.GetPage(tgphPath, true)
if err != nil {
resultCh <- fmt.Errorf("获取 telegraph 页面失败: %w", err)
return
}
imgs := make([]string, 0)
for _, element := range page.Content {
var node telegraph.NodeElement
data, err := json.Marshal(element)
if err != nil {
common.Log.Errorf("Failed to marshal element: %s", err)
continue
}
err = json.Unmarshal(data, &node)
if err != nil {
common.Log.Errorf("Failed to unmarshal element: %s", err)
continue
}
if node.Tag == "img" {
if src, ok := node.Attrs["src"]; ok {
imgs = append(imgs, src)
}
}
}
if len(imgs) == 0 {
resultCh <- fmt.Errorf("没有找到图片")
return
}
hc := bot.TelegraphClient.HttpClient
eg, ectx := errgroup.WithContext(cancelCtx)
eg.SetLimit(config.Cfg.Workers) // TODO: use a new config field for this
for i, img := range imgs {
eg.Go(func() error {
var lastErr error
for attempt := range config.Cfg.Retry {
if attempt > 0 {
retryDelay := time.Duration(attempt*attempt) * time.Second
select {
case <-ectx.Done():
return ectx.Err()
case <-time.After(retryDelay):
}
common.Log.Debugf("Retrying to download image %s (attempt %d)", img, attempt+1)
}
req, err := http.NewRequestWithContext(ectx, http.MethodGet, img, nil)
if err != nil {
lastErr = fmt.Errorf("创建请求失败: %w", err)
continue
}
resp, err := hc.Do(req)
if err != nil {
lastErr = fmt.Errorf("发送请求失败: %w", err)
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("请求图片失败: %s", resp.Status)
continue
}
targetPath := path.Join(task.StoragePath, fmt.Sprintf("%d%s", i+1, path.Ext(img)))
err = taskStorage.Save(ectx, resp.Body, targetPath)
if err != nil {
lastErr = fmt.Errorf("保存图片失败: %w", err)
continue
}
common.Log.Infof("Saved image: %s", targetPath)
return nil
}
return lastErr
})
}
if err := eg.Wait(); err != nil {
resultCh <- err
return
}
resultCh <- nil
}()
select {
case err := <-resultCh:
return err
case <-cancelCtx.Done():
return cancelCtx.Err()
}
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storage.Storage, cacheFilePath string) error {
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
for i := 0; i <= config.Cfg.Retry; i++ {
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
@@ -30,7 +30,7 @@ func saveFileWithRetry(ctx context.Context, task *types.Task, taskStorage storag
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
if err := taskStorage.Save(ctx, file, task.StoragePath); err != nil {
if err := taskStorage.Save(ctx, file, storagePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}

View File

@@ -14,6 +14,8 @@ type ReceivedFile struct {
ReplyMessageID int
ReplyChatID int64
FileName string
IsTelegraph bool
TelegraphURL string
}
type User struct {

1
go.mod
View File

@@ -5,6 +5,7 @@ go 1.23.5
require (
github.com/blang/semver v3.5.1+incompatible
github.com/celestix/gotgproto v1.0.0-beta20.2
github.com/celestix/telegraph-go/v2 v2.0.4
github.com/gabriel-vasile/mimetype v1.4.8
github.com/gookit/slog v0.5.7
github.com/gotd/contrib v0.21.0

2
go.sum
View File

@@ -4,6 +4,8 @@ github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdn
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
github.com/celestix/gotgproto v1.0.0-beta20.2 h1:+WcsKdsyj4xy+TAV+4Sw6zp1xiQrIr4dMnM31+k8NYM=
github.com/celestix/gotgproto v1.0.0-beta20.2/go.mod h1:j42ZhBMUke6QyBLvCgx8tA+TL9L3+pq/Q46B+b5+3aU=
github.com/celestix/telegraph-go/v2 v2.0.4 h1:w8HWymJFhMSMPjdGoyTh3/NqE3eXAT1njTvelh0338k=
github.com/celestix/telegraph-go/v2 v2.0.4/go.mod h1:vu2LtqM7MgOAJ2LDF8XK27DWdd1QYLBfZGhalEh086Y=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=

View File

@@ -5,6 +5,8 @@ import (
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
@@ -41,29 +43,46 @@ type Task struct {
Cancel context.CancelFunc
Error error
Status TaskStatus
File *File
StorageName string
StoragePath string
StartTime time.Time
File *File
FileMessageID int
FileChatID int64
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
// to track the user
UserID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}

12
types/utils.go Normal file
View File

@@ -0,0 +1,12 @@
package types
import (
"crypto/md5"
"encoding/hex"
)
func hashStr(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}