Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ffd9891a0 | ||
|
|
347a60f1f7 | ||
|
|
da69fe1354 |
32
bot/utils.go
32
bot/utils.go
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
@@ -277,6 +278,19 @@ func GenFileNameFromMessage(message tg.Message, file *types.File) string {
|
||||
if file.FileName != "" {
|
||||
return file.FileName
|
||||
}
|
||||
fileName := genFileNameFromMessageText(message, file)
|
||||
media, ok := message.GetMedia()
|
||||
if !ok {
|
||||
return fileName
|
||||
}
|
||||
ext, ok := extraMediaExt(media)
|
||||
if ok {
|
||||
return fileName + ext
|
||||
}
|
||||
return fileName
|
||||
}
|
||||
|
||||
func genFileNameFromMessageText(message tg.Message, file *types.File) string {
|
||||
text := strings.TrimSpace(message.GetMessage())
|
||||
if text == "" {
|
||||
return file.Hash()
|
||||
@@ -288,3 +302,21 @@ func GenFileNameFromMessage(message tg.Message, file *types.File) string {
|
||||
runes := []rune(text)
|
||||
return string(runes[:min(128, len(runes))])
|
||||
}
|
||||
|
||||
func extraMediaExt(media tg.MessageMediaClass) (string, bool) {
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
doc, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
ext := mimetype.Lookup(doc.MimeType).Extension()
|
||||
if ext == "" {
|
||||
return "", false
|
||||
}
|
||||
return ext, true
|
||||
case *tg.MessageMediaPhoto:
|
||||
return ".jpg", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/gotd/td/telegram/downloader"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
@@ -13,6 +14,12 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
}
|
||||
|
||||
func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
|
||||
for {
|
||||
semaphore <- struct{}{}
|
||||
|
||||
@@ -188,6 +188,13 @@ func processTelegraph(extCtx *ext.Context, cancelCtx context.Context, task *type
|
||||
common.Log.Errorf("Failed to unmarshal element: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(node.Children) != 0 {
|
||||
for _, child := range node.Children {
|
||||
imgs = append(imgs, getNodeImages(child)...)
|
||||
}
|
||||
}
|
||||
|
||||
if node.Tag == "img" {
|
||||
if src, ok := node.Attrs["src"]; ok {
|
||||
imgs = append(imgs, src)
|
||||
|
||||
80
core/download_test.go
Normal file
80
core/download_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
)
|
||||
|
||||
func TestGetImgSrcs(t *testing.T) {
|
||||
complexStructure := telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image1.png",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "p",
|
||||
Children: []telegraph.Node{
|
||||
"A text node",
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "figure",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image2.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image3.png",
|
||||
},
|
||||
},
|
||||
"text node",
|
||||
telegraph.NodeElement{
|
||||
Tag: "div",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "span",
|
||||
Children: []telegraph.Node{
|
||||
telegraph.NodeElement{
|
||||
Tag: "img",
|
||||
Attrs: map[string]string{
|
||||
"src": "https://example.com/image4.png",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"https://example.com/image1.png",
|
||||
"https://example.com/image2.png",
|
||||
"https://example.com/image3.png",
|
||||
"https://example.com/image4.png",
|
||||
}
|
||||
|
||||
got := getNodeImages(complexStructure)
|
||||
|
||||
if !reflect.DeepEqual(expected, got) {
|
||||
t.Errorf("expected %v,got %v", expected, got)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package core
|
||||
|
||||
import "github.com/gotd/td/telegram/downloader"
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package core
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/celestix/telegraph-go/v2"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
@@ -22,22 +24,33 @@ import (
|
||||
)
|
||||
|
||||
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
fileStat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file stat: %w", err)
|
||||
}
|
||||
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
|
||||
for i := 0; i <= config.Cfg.Retry; i++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if err := vctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled while saving file: %w", err)
|
||||
}
|
||||
file, err := os.Open(cacheFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open cache file: %w", err)
|
||||
}
|
||||
if err := taskStorage.Save(ctx, file, storagePath); err != nil {
|
||||
defer file.Close()
|
||||
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
|
||||
if i == config.Cfg.Retry {
|
||||
return fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
common.Log.Errorf("Failed to save file: %s, retrying...", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", ctx.Err())
|
||||
case <-vctx.Done():
|
||||
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
|
||||
case <-time.After(time.Duration(i*500) * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
@@ -256,3 +269,27 @@ func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, co
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
func getNodeImages(node telegraph.Node) []string {
|
||||
var srcs []string
|
||||
|
||||
var nodeElement telegraph.NodeElement
|
||||
data, err := json.Marshal(node)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
err = json.Unmarshal(data, &nodeElement)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
|
||||
if nodeElement.Tag == "img" {
|
||||
if src, exists := nodeElement.Attrs["src"]; exists {
|
||||
srcs = append(srcs, src)
|
||||
}
|
||||
}
|
||||
for _, child := range nodeElement.Children {
|
||||
srcs = append(srcs, getNodeImages(child)...)
|
||||
}
|
||||
return srcs
|
||||
}
|
||||
|
||||
@@ -106,6 +106,12 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string)
|
||||
req.Header.Set("Authorization", a.token)
|
||||
req.Header.Set("File-Path", url.PathEscape(storagePath))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
length, ok := length.(int64)
|
||||
if ok {
|
||||
req.ContentLength = length
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/types"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -38,6 +40,11 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read
|
||||
if c.Username != "" && c.Password != "" {
|
||||
req.SetBasicAuth(c.Username, c.Password)
|
||||
}
|
||||
if length := ctx.Value(types.ContextKeyContentLength); length != nil {
|
||||
if l, ok := length.(int64); ok {
|
||||
req.ContentLength = l
|
||||
}
|
||||
}
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
|
||||
82
types/task.go
Normal file
82
types/task.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
// Task describes a single unit of work handled by the bot.
//
// A task refers either to a file carried by a message (File, FileMessageID,
// FileChatID) or, when IsTelegraph is true, to a Telegraph page identified by
// TelegraphURL. The Key, String and FileName methods switch on IsTelegraph
// accordingly.
type Task struct {
|
||||
// Ctx and Cancel are the context and cancel function carried with the task.
// NOTE(review): presumably Cancel cancels Ctx — confirm where tasks are created.
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
// Error and Status record the task's outcome and current state.
Error error
|
||||
Status TaskStatus
|
||||
// StorageName and StoragePath name the destination of the task's data.
// NOTE(review): exact semantics are defined by the storage code — confirm.
StorageName string
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
|
||||
// File, FileMessageID and FileChatID describe the source file; they are
// the fields read by Key, String and FileName when IsTelegraph is false.
File *File
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
|
||||
// IsTelegraph marks the task as a Telegraph page task; TelegraphURL is
// then used instead of the file fields.
IsTelegraph bool
|
||||
TelegraphURL string
|
||||
|
||||
// to track the reply message
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) Key() string {
|
||||
if t.IsTelegraph {
|
||||
return hashStr(t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
if t.IsTelegraph {
|
||||
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
func (t Task) FileName() string {
|
||||
if t.IsTelegraph {
|
||||
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
|
||||
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
|
||||
if err != nil {
|
||||
return tgphPath
|
||||
}
|
||||
return tgphPathUnescaped
|
||||
}
|
||||
return t.File.FileName
|
||||
}
|
||||
|
||||
// File describes a file referenced by a Task: its input location, its size
// and its name. All three fields feed into Hash.
type File struct {
|
||||
// Location is the file's input location value; Hash uses its String form.
Location tg.InputFileLocationClass
|
||||
// FileSize is the file size. NOTE(review): presumably in bytes — confirm.
FileSize int64
|
||||
FileName string
|
||||
}
|
||||
|
||||
func (f File) Hash() string {
|
||||
locationBytes := []byte(f.Location.String())
|
||||
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
|
||||
fileNameBytes := []byte(f.FileName)
|
||||
|
||||
structBytes := append(locationBytes, fileSizeBytes...)
|
||||
structBytes = append(structBytes, fileNameBytes...)
|
||||
|
||||
hash := md5.New()
|
||||
hash.Write(structBytes)
|
||||
hashBytes := hash.Sum(nil)
|
||||
|
||||
return hex.EncodeToString(hashBytes)
|
||||
}
|
||||
@@ -1,20 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type TaskStatus string
|
||||
|
||||
var (
|
||||
const (
|
||||
Pending TaskStatus = "pending"
|
||||
Succeeded TaskStatus = "succeeded"
|
||||
Failed TaskStatus = "failed"
|
||||
@@ -23,7 +11,7 @@ var (
|
||||
|
||||
type StorageType string
|
||||
|
||||
var (
|
||||
const (
|
||||
StorageTypeLocal StorageType = "local"
|
||||
StorageTypeWebdav StorageType = "webdav"
|
||||
StorageTypeAlist StorageType = "alist"
|
||||
@@ -38,71 +26,8 @@ var StorageTypeDisplay = map[StorageType]string{
|
||||
StorageTypeMinio: "Minio",
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
Error error
|
||||
Status TaskStatus
|
||||
StorageName string
|
||||
StoragePath string
|
||||
StartTime time.Time
|
||||
type ContextKey string
|
||||
|
||||
File *File
|
||||
FileMessageID int
|
||||
FileChatID int64
|
||||
|
||||
IsTelegraph bool
|
||||
TelegraphURL string
|
||||
|
||||
// to track the reply message
|
||||
ReplyMessageID int
|
||||
ReplyChatID int64
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (t Task) Key() string {
|
||||
if t.IsTelegraph {
|
||||
return hashStr(t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
|
||||
}
|
||||
|
||||
func (t Task) String() string {
|
||||
if t.IsTelegraph {
|
||||
return fmt.Sprintf("[telegraph]:%s", t.TelegraphURL)
|
||||
}
|
||||
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
|
||||
}
|
||||
|
||||
func (t Task) FileName() string {
|
||||
if t.IsTelegraph {
|
||||
tgphPath := strings.Split(t.TelegraphURL, "/")[len(strings.Split(t.TelegraphURL, "/"))-1]
|
||||
tgphPathUnescaped, err := url.PathUnescape(tgphPath)
|
||||
if err != nil {
|
||||
return tgphPath
|
||||
}
|
||||
return tgphPathUnescaped
|
||||
}
|
||||
return t.File.FileName
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Location tg.InputFileLocationClass
|
||||
FileSize int64
|
||||
FileName string
|
||||
}
|
||||
|
||||
func (f File) Hash() string {
|
||||
locationBytes := []byte(f.Location.String())
|
||||
fileSizeBytes := []byte(fmt.Sprintf("%d", f.FileSize))
|
||||
fileNameBytes := []byte(f.FileName)
|
||||
|
||||
structBytes := append(locationBytes, fileSizeBytes...)
|
||||
structBytes = append(structBytes, fileNameBytes...)
|
||||
|
||||
hash := md5.New()
|
||||
hash.Write(structBytes)
|
||||
hashBytes := hash.Sum(nil)
|
||||
|
||||
return hex.EncodeToString(hashBytes)
|
||||
}
|
||||
const (
|
||||
ContextKeyContentLength ContextKey = "content-length"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user