Compare commits

...

9 Commits

14 changed files with 326 additions and 136 deletions

View File

@@ -1,6 +1,8 @@
package bot package bot
import ( import (
"fmt"
"github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/ext"
"github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/common"
@@ -17,6 +19,7 @@ func start(ctx *ext.Context, update *ext.Update) error {
const helpText string = ` const helpText string = `
Save Any Bot - 转存你的 Telegram 文件 Save Any Bot - 转存你的 Telegram 文件
版本: %s , 提交: %s
命令: 命令:
/start - 开始使用 /start - 开始使用
/help - 显示帮助 /help - 显示帮助
@@ -32,6 +35,6 @@ Save Any Bot - 转存你的 Telegram 文件
` `
func help(ctx *ext.Context, update *ext.Update) error { func help(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString(helpText), nil) ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf(helpText, common.Version, common.GitCommit[:7])), nil)
return dispatcher.EndGroups return dispatcher.EndGroups
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/celestix/gotgproto/dispatcher" "github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/ext"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity" "github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg" "github.com/gotd/td/tg"
@@ -200,13 +201,7 @@ func FileFromMessage(ctx *ext.Context, chatID int64, messageID int, customFileNa
} }
func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) { func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, error) {
key := fmt.Sprintf("message:%d:%d", chatId, messageID)
common.Log.Debugf("Fetching message: %d", messageID) common.Log.Debugf("Fetching message: %d", messageID)
var cachedMessage tg.Message
err := common.Cache.Get(key, &cachedMessage)
if err == nil {
return &cachedMessage, nil
}
messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}}) messages, err := ctx.GetMessages(chatId, []tg.InputMessageClass{&tg.InputMessageID{ID: messageID}})
if err != nil { if err != nil {
return nil, err return nil, err
@@ -219,9 +214,6 @@ func GetTGMessage(ctx *ext.Context, chatId int64, messageID int) (*tg.Message, e
if !ok { if !ok {
return nil, fmt.Errorf("unexpected message type: %T", msg) return nil, fmt.Errorf("unexpected message type: %T", msg)
} }
if err := common.Cache.Set(key, tgMessage, 3600); err != nil {
common.Log.Errorf("Failed to cache message: %s", err)
}
return tgMessage, nil return tgMessage, nil
} }
@@ -286,6 +278,19 @@ func GenFileNameFromMessage(message tg.Message, file *types.File) string {
if file.FileName != "" { if file.FileName != "" {
return file.FileName return file.FileName
} }
fileName := genFileNameFromMessageText(message, file)
media, ok := message.GetMedia()
if !ok {
return fileName
}
ext, ok := extraMediaExt(media)
if ok {
return fileName + ext
}
return fileName
}
func genFileNameFromMessageText(message tg.Message, file *types.File) string {
text := strings.TrimSpace(message.GetMessage()) text := strings.TrimSpace(message.GetMessage())
if text == "" { if text == "" {
return file.Hash() return file.Hash()
@@ -297,3 +302,21 @@ func GenFileNameFromMessage(message tg.Message, file *types.File) string {
runes := []rune(text) runes := []rune(text)
return string(runes[:min(128, len(runes))]) return string(runes[:min(128, len(runes))])
} }
func extraMediaExt(media tg.MessageMediaClass) (string, bool) {
switch media := media.(type) {
case *tg.MessageMediaDocument:
doc, ok := media.Document.AsNotEmpty()
if !ok {
return "", false
}
ext := mimetype.Lookup(doc.MimeType).Extension()
if ext == "" {
return "", false
}
return ext, true
case *tg.MessageMediaPhoto:
return ".jpg", true
}
return "", false
}

View File

@@ -21,8 +21,6 @@ func initCache() {
gob.Register(types.File{}) gob.Register(types.File{})
gob.Register(tg.InputDocumentFileLocation{}) gob.Register(tg.InputDocumentFileLocation{})
gob.Register(tg.InputPhotoFileLocation{}) gob.Register(tg.InputPhotoFileLocation{})
gob.Register(tg.Message{})
gob.Register(tg.PeerUser{})
Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)} Cache = &CommonCache{cache: freecache.NewCache(10 * 1024 * 1024)}
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/ext"
"github.com/gotd/td/telegram/downloader"
"github.com/gotd/td/tg" "github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common" "github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
@@ -13,6 +14,12 @@ import (
"github.com/krau/SaveAny-Bot/types" "github.com/krau/SaveAny-Bot/types"
) )
var Downloader *downloader.Downloader
func init() {
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
}
func worker(queue *queue.TaskQueue, semaphore chan struct{}) { func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
for { for {
semaphore <- struct{}{} semaphore <- struct{}{}

View File

@@ -59,41 +59,50 @@ func processPendingTask(task *types.Task) error {
downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) downloadBuilder := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
notsupportStreamStorage, notsupportStream := taskStorage.(storage.StorageNotSupportStream)
cancelMarkUp := getCancelTaskMarkup(task)
if config.Cfg.Stream { if config.Cfg.Stream {
if !notsupportStream {
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) pr, pw := io.Pipe()
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ defer pr.Close()
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
pr, pw := io.Pipe() task.StartTime = time.Now()
defer pr.Close() progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
task.StartTime = time.Now() progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback)
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
progressStream := NewProgressStream(pw, task.File.FileSize, progressCallback) eg, uploadCtx := errgroup.WithContext(cancelCtx)
eg, uploadCtx := errgroup.WithContext(cancelCtx) eg.Go(func() error {
return taskStorage.Save(uploadCtx, pr, task.StoragePath)
eg.Go(func() error { })
return taskStorage.Save(uploadCtx, pr, task.StoragePath) eg.Go(func() error {
}) _, err := downloadBuilder.Stream(uploadCtx, progressStream)
eg.Go(func() error { if closeErr := pw.CloseWithError(err); closeErr != nil {
_, err := downloadBuilder.Stream(uploadCtx, progressStream) common.Log.Errorf("Failed to close pipe writer: %v", closeErr)
if closeErr := pw.CloseWithError(err); closeErr != nil { }
common.Log.Errorf("Failed to close pipe writer: %v", closeErr) return err
})
if err := eg.Wait(); err != nil {
return err
} }
return err
})
if err := eg.Wait(); err != nil {
return err
}
return nil return nil
}
common.Log.Warnf("存储 %s 不支持流式传输: %s", task.StorageName, notsupportStreamStorage.NotSupportStream())
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: fmt.Sprintf("存储 %s 不支持流式传输: %s\n正在使用普通下载...", task.StorageName, notsupportStreamStorage.NotSupportStream()),
ID: task.ReplyMessageID,
ReplyMarkup: cancelMarkUp,
})
} }
cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName()) cacheDestPath := filepath.Join(config.Cfg.Temp.BasePath, task.FileName())
@@ -110,7 +119,7 @@ func processPendingTask(task *types.Task) error {
Message: text, Message: text,
Entities: entities, Entities: entities,
ID: task.ReplyMessageID, ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task), ReplyMarkup: cancelMarkUp,
}) })
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
@@ -188,6 +197,13 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
common.Log.Errorf("Failed to unmarshal element: %s", err) common.Log.Errorf("Failed to unmarshal element: %s", err)
continue continue
} }
if len(node.Children) != 0 {
for _, child := range node.Children {
imgs = append(imgs, getNodeImages(child)...)
}
}
if node.Tag == "img" { if node.Tag == "img" {
if src, ok := node.Attrs["src"]; ok { if src, ok := node.Attrs["src"]; ok {
imgs = append(imgs, src) imgs = append(imgs, src)

80
core/download_test.go Normal file
View File

@@ -0,0 +1,80 @@
package core
import (
"reflect"
"testing"
"github.com/celestix/telegraph-go/v2"
)
func TestGetImgSrcs(t *testing.T) {
complexStructure := telegraph.NodeElement{
Tag: "div",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "figure",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image1.png",
},
},
telegraph.NodeElement{
Tag: "p",
Children: []telegraph.Node{
"A text node",
},
},
telegraph.NodeElement{
Tag: "figure",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image2.png",
},
},
},
},
},
},
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image3.png",
},
},
"text node",
telegraph.NodeElement{
Tag: "div",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "span",
Children: []telegraph.Node{
telegraph.NodeElement{
Tag: "img",
Attrs: map[string]string{
"src": "https://example.com/image4.png",
},
},
},
},
},
},
},
}
expected := []string{
"https://example.com/image1.png",
"https://example.com/image2.png",
"https://example.com/image3.png",
"https://example.com/image4.png",
}
got := getNodeImages(complexStructure)
if !reflect.DeepEqual(expected, got) {
t.Errorf("expected %vgot %v", expected, got)
}
}

View File

@@ -1,9 +0,0 @@
package core
import "github.com/gotd/td/telegram/downloader"
var Downloader *downloader.Downloader
func init() {
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
}

View File

@@ -3,6 +3,7 @@ package core
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -10,6 +11,7 @@ import (
"time" "time"
"github.com/celestix/gotgproto/ext" "github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gabriel-vasile/mimetype" "github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity" "github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling" "github.com/gotd/td/telegram/message/styling"
@@ -22,22 +24,33 @@ import (
) )
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error { func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
fileStat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
for i := 0; i <= config.Cfg.Retry; i++ { for i := 0; i <= config.Cfg.Retry; i++ {
if err := ctx.Err(); err != nil { if err := vctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err) return fmt.Errorf("context canceled while saving file: %w", err)
} }
file, err := os.Open(cacheFilePath) file, err := os.Open(cacheFilePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open cache file: %w", err) return fmt.Errorf("failed to open cache file: %w", err)
} }
if err := taskStorage.Save(ctx, file, storagePath); err != nil { defer file.Close()
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
if i == config.Cfg.Retry { if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err) return fmt.Errorf("failed to save file: %w", err)
} }
common.Log.Errorf("Failed to save file: %s, retrying...", err) common.Log.Errorf("Failed to save file: %s, retrying...", err)
select { select {
case <-ctx.Done(): case <-vctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err()) return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond): case <-time.After(time.Duration(i*500) * time.Millisecond):
} }
continue continue
@@ -256,3 +269,27 @@ func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, co
interval: interval, interval: interval,
} }
} }
func getNodeImages(node telegraph.Node) []string {
var srcs []string
var nodeElement telegraph.NodeElement
data, err := json.Marshal(node)
if err != nil {
return srcs
}
err = json.Unmarshal(data, &nodeElement)
if err != nil {
return srcs
}
if nodeElement.Tag == "img" {
if src, exists := nodeElement.Attrs["src"]; exists {
srcs = append(srcs, src)
}
}
for _, child := range nodeElement.Children {
srcs = append(srcs, getNodeImages(child)...)
}
return srcs
}

View File

@@ -4,10 +4,14 @@
Bot 接受两种消息: 文件和链接. Bot 接受两种消息: 文件和链接.
目前, 链接仅支持公开频道 (具有用户名) 的链接, 例如: `https://t.me/acherkrau/1097`. 支持以下链接:
1. 公开频道 (具有用户名) 的消息链接, 例如: `https://t.me/acherkrau/1097`.
**即使频道禁止了转发和保存, Bot 依然可以下载其文件.** **即使频道禁止了转发和保存, Bot 依然可以下载其文件.**
2. Telegra.ph 的文章链接, Bot 将下载其中的所有图片
## 静默模式 (silent) ## 静默模式 (silent)
使用 `/silent` 命令可以开关静默模式. 使用 `/silent` 命令可以开关静默模式.
@@ -32,4 +36,6 @@ Bot 接受两种消息: 文件和链接.
- 网络不稳定时, 任务失败率高. - 网络不稳定时, 任务失败率高.
- 无法在中间层对文件进行处理, 例如自动文件类型识别. - 无法在中间层对文件进行处理, 例如自动文件类型识别.
虽然目前 Bot 适配的所有存储端 (Alist, 本地磁盘, Webdav) 都支持 Stream 模式, 但今后可能会有不支持的存储端, 此时即使开启 Stream 模式, Bot 也会自动切换到普通模式. **不支持** Stream 模式的存储端:
- alist

View File

@@ -106,6 +106,12 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
req.Header.Set("Authorization", a.token) req.Header.Set("Authorization", a.token)
req.Header.Set("File-Path", url.PathEscape(storagePath)) req.Header.Set("File-Path", url.PathEscape(storagePath))
req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Type", "application/octet-stream")
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
length, ok := length.(int64)
if ok {
req.ContentLength = length
}
}
resp, err := a.client.Do(req) resp, err := a.client.Do(req)
if err != nil { if err != nil {
@@ -134,6 +140,10 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
return nil return nil
} }
func (a *Alist) NotSupportStream() string {
return "Alist does not support chunked transfer encoding"
}
func (a *Alist) JoinStoragePath(task types.Task) string { func (a *Alist) JoinStoragePath(task types.Task) string {
return path.Join(a.config.BasePath, task.StoragePath) return path.Join(a.config.BasePath, task.StoragePath)
} }

View File

@@ -23,6 +23,11 @@ type Storage interface {
Save(ctx context.Context, reader io.Reader, storagePath string) error Save(ctx context.Context, reader io.Reader, storagePath string) error
} }
type StorageNotSupportStream interface {
Storage
NotSupportStream() string
}
var Storages = make(map[string]Storage) var Storages = make(map[string]Storage)
var UserStorages = make(map[int64][]Storage) var UserStorages = make(map[int64][]Storage)

View File

@@ -6,6 +6,8 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/krau/SaveAny-Bot/types"
) )
type Client struct { type Client struct {
@@ -38,6 +40,11 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read
if c.Username != "" && c.Password != "" { if c.Username != "" && c.Password != "" {
req.SetBasicAuth(c.Username, c.Password) req.SetBasicAuth(c.Username, c.Password)
} }
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
if l, ok := length.(int64); ok {
req.ContentLength = l
}
}
return c.httpClient.Do(req) return c.httpClient.Do(req)
} }

82
types/task.go Normal file
View File

@@ -0,0 +1,82 @@
package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
)
type Task struct {
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
StorageName string
StoragePath string
StartTime time.Time
File *File
FileMessageID int
FileChatID int64
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}

View File

@@ -1,20 +1,8 @@
package types package types
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
"github.com/gotd/td/tg"
)
type TaskStatus string type TaskStatus string
var ( const (
Pending TaskStatus = "pending" Pending TaskStatus = "pending"
Succeeded TaskStatus = "succeeded" Succeeded TaskStatus = "succeeded"
Failed TaskStatus = "failed" Failed TaskStatus = "failed"
@@ -23,7 +11,7 @@ var (
type StorageType string type StorageType string
var ( const (
StorageTypeLocal StorageType = "local" StorageTypeLocal StorageType = "local"
StorageTypeWebdav StorageType = "webdav" StorageTypeWebdav StorageType = "webdav"
StorageTypeAlist StorageType = "alist" StorageTypeAlist StorageType = "alist"
@@ -38,71 +26,8 @@ var StorageTypeDisplay = map[StorageType]string{
StorageTypeMinio: "Minio", StorageTypeMinio: "Minio",
} }
type Task struct { type ContextKey string
Ctx context.Context
Cancel context.CancelFunc
Error error
Status TaskStatus
StorageName string
StoragePath string
StartTime time.Time
File *File const (
FileMessageID int ContextKeyContentLength ContextKey = "content-length"
FileChatID int64 )
IsTelegraph bool
TelegraphURL string
// to track the reply message
ReplyMessageID int
ReplyChatID int64
UserID int64
}
func (t Task) Key() string {
if t.IsTelegraph {
return hashStr(t.TelegraphURL)
}
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
}
func (t Task) String() string {
if t.IsTelegraph {
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
}
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
}
func (t Task) FileName() string {
if t.IsTelegraph {
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
if err != nil {
return tgphPath
}
return tgphPathUnescaped
}
return t.File.FileName
}
type File struct {
Location tg.InputFileLocationClass
FileSize int64
FileName string
}
func (f File) Hash() string {
locationBytes := []byte(f.Location.String())
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
fileNameBytes := []byte(f.FileName)
structBytes := append(locationBytes, fileSizeBytes...)
structBytes = append(structBytes, fileNameBytes...)
hash := md5.New()
hash.Write(structBytes)
hashBytes := hash.Sum(nil)
return hex.EncodeToString(hashBytes)
}