refactor: refactor task logic for better scalability (#76)
* refactor: a big refactor. wip * refactor: port handle file * refactor: place all handlers * fix: task info nil pointer * feat: enhance task progress tracking and context management * feat: cancel task * feat: stream mode * feat: silent mode * feat: dir cmd * refactor: remove unused old file * feat: rule cmd * feat: handle silent mode * feat: batch task * fix: batch task progress and temp file cleanup * refactor: update file creation and cleanup methods for better resource management * feat: add save command with silent mode handling * feat: message link * feat: update message prompts to include file count in storage selection * feat: slient save links * refactor: reduce dup code * feat: rule type * feat: chose dir * feat: refactor file handling and storage rules, improve error handling and logging * feat: rule mode * feat: telegraph pics * fix: tphpics nil pointer and inaccurate dirpath * feat: silent save telegraph * feat: add suffix to avoid file overwrite * feat: new storage telegram * chore: tidy go mod
This commit is contained in:
33
common/utils/dlutil/dl.go
Normal file
33
common/utils/dlutil/dl.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package dlutil
|
||||
|
||||
import "time"
|
||||
|
||||
var threadsLevels = []struct {
|
||||
threads int
|
||||
size int64
|
||||
}{
|
||||
{1, 10 << 20},
|
||||
{2, 50 << 20},
|
||||
{4, 200 << 20},
|
||||
{8, 500 << 20},
|
||||
}
|
||||
|
||||
func BestThreads(size int64, max int) int {
|
||||
for _, thread := range threadsLevels {
|
||||
if size < thread.size {
|
||||
return min(thread.threads, max)
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func GetSpeed(downloaded int64, startTime time.Time) float64 {
|
||||
if startTime.IsZero() {
|
||||
return 0
|
||||
}
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
if elapsed <= 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(downloaded) / elapsed
|
||||
}
|
||||
57
common/utils/fsutil/fs.go
Normal file
57
common/utils/fsutil/fs.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package fsutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
)
|
||||
|
||||
// 删除文件夹内的所有文件和子目录, 但不删除文件夹本身
|
||||
func RemoveAllInDir(dirPath string) error {
|
||||
entries, err := os.ReadDir(dirPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(dirPath, entry.Name())
|
||||
if err := os.RemoveAll(entryPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DetectFileExt(fp string) string {
|
||||
mt, err := mimetype.DetectFile(fp)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return mt.Extension()
|
||||
}
|
||||
|
||||
type File struct {
|
||||
*os.File
|
||||
}
|
||||
|
||||
func (f *File) Remove() error {
|
||||
return os.Remove(f.Name())
|
||||
}
|
||||
|
||||
func (f *File) CloseAndRemove() error {
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Remove()
|
||||
}
|
||||
|
||||
func CreateFile(fp string) (*File, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.Create(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &File{File: file}, nil
|
||||
}
|
||||
49
common/utils/ioutil/writer.go
Normal file
49
common/utils/ioutil/writer.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package ioutil
|
||||
|
||||
import "io"
|
||||
|
||||
type ProgressWriterAt struct {
|
||||
wrAt io.WriterAt
|
||||
onWrite func(n int)
|
||||
}
|
||||
|
||||
func (p *ProgressWriterAt) WriteAt(buf []byte, off int64) (n int, err error) {
|
||||
n, err = p.wrAt.WriteAt(buf, off)
|
||||
if n > 0 {
|
||||
p.onWrite(n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewProgressWriterAt(
|
||||
wrAt io.WriterAt,
|
||||
onWrite func(n int),
|
||||
) *ProgressWriterAt {
|
||||
return &ProgressWriterAt{
|
||||
wrAt: wrAt,
|
||||
onWrite: onWrite,
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
wr io.Writer
|
||||
onWrite func(n int)
|
||||
}
|
||||
|
||||
func (p *ProgressWriter) Write(buf []byte) (n int, err error) {
|
||||
n, err = p.wr.Write(buf)
|
||||
if n > 0 {
|
||||
p.onWrite(n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewProgressWriter(
|
||||
wr io.Writer,
|
||||
onWrite func(n int),
|
||||
) *ProgressWriter {
|
||||
return &ProgressWriter{
|
||||
wr: wr,
|
||||
onWrite: onWrite,
|
||||
}
|
||||
}
|
||||
50
common/utils/strutil/string.go
Normal file
50
common/utils/strutil/string.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package strutil
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
)
|
||||
|
||||
func HashString(s string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(s))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
var TagRe = regexp.MustCompile(`(?:^|[\p{Zs}\s.,!?(){}[\]<>\"\',。!?():;、])#([\p{L}\d_]+)`)
|
||||
|
||||
func ExtractTagsFromText(text string) []string {
|
||||
matches := TagRe.FindAllStringSubmatch(text, -1)
|
||||
tags := make([]string, 0)
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
tags = append(tags, match[1])
|
||||
}
|
||||
}
|
||||
return slice.Compact(tags)
|
||||
}
|
||||
|
||||
func ParseIntStrRange(input string, sep string) (int64, int64, error) {
|
||||
parts := strings.Split(input, sep)
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, fmt.Errorf("invalid range format: %s", input)
|
||||
}
|
||||
min, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid minimum value: %s", parts[0])
|
||||
}
|
||||
max, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid maximum value: %s", parts[1])
|
||||
}
|
||||
if min > max {
|
||||
min, max = max, min
|
||||
}
|
||||
return min, max, nil
|
||||
}
|
||||
22
common/utils/tgutil/context.go
Normal file
22
common/utils/tgutil/context.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package tgutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
var extKey = contextKey{}
|
||||
|
||||
func ExtFromContext(ctx context.Context) *ext.Context {
|
||||
if extCtx, ok := ctx.Value(extKey).(*ext.Context); ok {
|
||||
return extCtx
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExtWithContext(ctx context.Context, extCtx *ext.Context) context.Context {
|
||||
return context.WithValue(ctx, extKey, extCtx)
|
||||
}
|
||||
183
common/utils/tgutil/message.go
Normal file
183
common/utils/tgutil/message.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package tgutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/duke-git/lancet/v2/maputil"
|
||||
"github.com/duke-git/lancet/v2/mathutil"
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
lcstrutil "github.com/duke-git/lancet/v2/strutil"
|
||||
"github.com/duke-git/lancet/v2/validator"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/cache"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/strutil"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
func GenFileNameFromMessage(message tg.Message) string {
|
||||
ext := func(media tg.MessageMediaClass) string {
|
||||
switch media := media.(type) {
|
||||
case *tg.MessageMediaDocument:
|
||||
doc, ok := media.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
ext := mimetype.Lookup(doc.MimeType).Extension()
|
||||
if ext == "" {
|
||||
return ""
|
||||
}
|
||||
return ext
|
||||
case *tg.MessageMediaPhoto:
|
||||
return ".jpg"
|
||||
}
|
||||
return ""
|
||||
}(message.Media)
|
||||
text := strings.TrimSpace(message.GetMessage())
|
||||
if text == "" {
|
||||
return fmt.Sprintf("%d_%s%s", message.GetID(), xid.New().String(), ext)
|
||||
}
|
||||
filename := func() string {
|
||||
tags := strutil.ExtractTagsFromText(text)
|
||||
if len(tags) > 0 {
|
||||
tagStrRunes := make([]rune, 0, 64)
|
||||
for i, tag := range tags {
|
||||
if i > 0 {
|
||||
tagStrRunes = append(tagStrRunes, '_')
|
||||
}
|
||||
tagStrRunes = append(tagStrRunes, []rune(tag)...)
|
||||
if len(tagStrRunes) >= 64 {
|
||||
break
|
||||
}
|
||||
}
|
||||
tagStr := string(tagStrRunes)
|
||||
return fmt.Sprintf("%s_%s", tagStr, strconv.Itoa(message.GetID()))
|
||||
}
|
||||
text = lcstrutil.Substring(strings.Map(func(r rune) rune {
|
||||
if r < 0x20 || r == 0x7F {
|
||||
return '_'
|
||||
}
|
||||
switch r {
|
||||
// invalid characters
|
||||
case '/', '\\',
|
||||
':', '*', '?', '"', '<', '>', '|':
|
||||
return '_'
|
||||
// empty
|
||||
case ' ', '\t', '\r', '\n':
|
||||
return '_'
|
||||
}
|
||||
if validator.IsPrintable(string(r)) {
|
||||
return r
|
||||
}
|
||||
return '_'
|
||||
}, text), 0, 64)
|
||||
text = strings.Join(strings.FieldsFunc(text, func(r rune) bool {
|
||||
return r == '_' || r == ' '
|
||||
}), "_")
|
||||
return text
|
||||
}()
|
||||
|
||||
if filename == "" {
|
||||
filename = fmt.Sprintf("%d_%s", message.GetID(), xid.New().String())
|
||||
}
|
||||
return filename + ext
|
||||
}
|
||||
|
||||
func BuildCancelButton(taskID string) tg.KeyboardButtonClass {
|
||||
return &tg.KeyboardButtonCallback{
|
||||
Text: "取消任务",
|
||||
Data: fmt.Appendf(nil, "cancel %s", taskID),
|
||||
}
|
||||
}
|
||||
|
||||
func InputMessageClassSliceFromInt(ids []int) []tg.InputMessageClass {
|
||||
result := make([]tg.InputMessageClass, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
result = append(result, &tg.InputMessageID{
|
||||
ID: id,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func GetMessagesRange(ctx *ext.Context, chatID int64, minId, maxId int) ([]*tg.Message, error) {
|
||||
if minId > maxId {
|
||||
return nil, fmt.Errorf("minId (%d) cannot be greater than maxId (%d)", minId, maxId)
|
||||
}
|
||||
total := maxId - minId + 1
|
||||
msgIds := mathutil.Range(minId, total)
|
||||
toFetchIds := make([]int, 0, total)
|
||||
cached := make(map[int]*tg.Message, total)
|
||||
for _, id := range msgIds {
|
||||
if msg, ok := cache.Get[*tg.Message](fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, id)); ok {
|
||||
cached[id] = msg
|
||||
} else {
|
||||
toFetchIds = append(toFetchIds, id)
|
||||
}
|
||||
}
|
||||
if len(toFetchIds) == 0 {
|
||||
return maputil.Values(cached), nil
|
||||
}
|
||||
|
||||
result := make([]*tg.Message, 0, total)
|
||||
chunks := slice.Chunk(toFetchIds, 100)
|
||||
for _, chunk := range chunks {
|
||||
msgs, err := ctx.GetMessages(chatID, InputMessageClassSliceFromInt(chunk))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
tgMessage, ok := msg.(*tg.Message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if tgMessage.GetID() < minId || tgMessage.GetID() > maxId {
|
||||
continue
|
||||
}
|
||||
result = append(result, tgMessage)
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range result {
|
||||
cache.Set(fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, msg.GetID()), msg)
|
||||
}
|
||||
for _, msg := range cached {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, msg)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func GetMessageByID(ctx *ext.Context, chatID int64, msgID int) (*tg.Message, error) {
|
||||
key := fmt.Sprintf("tgmsg:%d:%d:%d", ctx.Self.ID, chatID, msgID)
|
||||
if msg, ok := cache.Get[*tg.Message](key); ok {
|
||||
return msg, nil
|
||||
}
|
||||
msgs, err := ctx.GetMessages(chatID, []tg.InputMessageClass{
|
||||
&tg.InputMessageID{ID: msgID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get message by ID: %w", err)
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
return nil, fmt.Errorf("message not found: chatID=%d, msgID=%d", chatID, msgID)
|
||||
}
|
||||
msg := msgs[0]
|
||||
tgm, ok := msg.(*tg.Message)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected message type: %T", msg)
|
||||
}
|
||||
cache.Set(key, tgm)
|
||||
return tgm, nil
|
||||
}
|
||||
119
common/utils/tgutil/resolve.go
Normal file
119
common/utils/tgutil/resolve.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package tgutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/duke-git/lancet/v2/validator"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func ParseChatID(ctx *ext.Context, idOrUsername string) (int64, error) {
|
||||
idOrUsername = strings.TrimPrefix(idOrUsername, "@")
|
||||
if validator.IsIntStr(idOrUsername) {
|
||||
chatID, err := strconv.Atoi(idOrUsername)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(chatID), nil
|
||||
}
|
||||
username := idOrUsername
|
||||
peer := ctx.PeerStorage.GetPeerByUsername(username)
|
||||
if peer != nil && peer.ID != 0 {
|
||||
return peer.ID, nil
|
||||
}
|
||||
chat, err := ctx.ResolveUsername(username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if chat == nil {
|
||||
return 0, fmt.Errorf("no chat found for username: %s", idOrUsername)
|
||||
}
|
||||
chatID := chat.GetID()
|
||||
if chatID == 0 {
|
||||
return 0, fmt.Errorf("chat ID is zero for username: %s", idOrUsername)
|
||||
}
|
||||
return chatID, nil
|
||||
}
|
||||
|
||||
// return: ChatID, MessageID, error
|
||||
func ParseMessageLink(ctx *ext.Context, link string) (int64, int, error) {
|
||||
u, err := url.Parse(link)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
|
||||
|
||||
if cmt := u.Query().Get("comment"); cmt != "" {
|
||||
// 频道评论的消息链接
|
||||
// https://t.me/acherkrau/123?comment=2
|
||||
chid, err := ParseChatID(ctx, paths[0])
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err)
|
||||
}
|
||||
chatfull, err := ctx.GetChat(chid)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to get chat: %w", err)
|
||||
}
|
||||
chfull, ok := chatfull.(*tg.ChannelFull)
|
||||
if !ok {
|
||||
return 0, 0, fmt.Errorf("chat is not a channel: %s", chatfull.TypeName())
|
||||
}
|
||||
linkChatId, ok := chfull.GetLinkedChatID()
|
||||
if !ok {
|
||||
return 0, 0, fmt.Errorf("channel has no linked chat")
|
||||
}
|
||||
msgID, err := strconv.Atoi(cmt)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse comment ID: %w", err)
|
||||
}
|
||||
return linkChatId, msgID, nil
|
||||
}
|
||||
|
||||
switch len(paths) {
|
||||
case 2: // https://t.me/acherkrau/123
|
||||
chatID, err := ParseChatID(ctx, paths[0])
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err)
|
||||
}
|
||||
msgID, err := strconv.Atoi(paths[1])
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
|
||||
}
|
||||
return chatID, msgID, nil
|
||||
case 3:
|
||||
// https://t.me/c/123456789/123
|
||||
// https://t.me/acherkrau/123/456 , 456: message thread ID
|
||||
chatPart, msgPart := paths[1], paths[2]
|
||||
if paths[0] != "c" {
|
||||
chatPart = paths[0]
|
||||
}
|
||||
chatID, err := ParseChatID(ctx, chatPart)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err)
|
||||
}
|
||||
msgID, err := strconv.Atoi(msgPart)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
|
||||
}
|
||||
return chatID, msgID, nil
|
||||
case 4:
|
||||
// https://t.me/c/123456789/111/456 111: topic id
|
||||
if paths[0] != "c" {
|
||||
return 0, 0, fmt.Errorf("invalid message link format: %s", link)
|
||||
}
|
||||
chatID, err := ParseChatID(ctx, paths[1])
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse chat ID: %w", err)
|
||||
}
|
||||
msgID, err := strconv.Atoi(paths[3])
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
|
||||
}
|
||||
return chatID, msgID, nil
|
||||
}
|
||||
return 0, 0, fmt.Errorf("invalid message link format: %s", link)
|
||||
}
|
||||
51
common/utils/tphutil/tph.go
Normal file
51
common/utils/tphutil/tph.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package tphutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
)
|
||||
|
||||
var tphClient *telegraph.Client
|
||||
|
||||
func DefaultClient() *telegraph.Client {
|
||||
if tphClient != nil {
|
||||
return tphClient
|
||||
}
|
||||
if config.Cfg.Telegram.Proxy.Enable && config.Cfg.Telegram.Proxy.URL != "" {
|
||||
proxyUrl := config.Cfg.Telegram.Proxy.URL
|
||||
var err error
|
||||
tphClient, err = telegraph.NewClientWithProxy(proxyUrl)
|
||||
if err != nil {
|
||||
tphClient = telegraph.NewClient()
|
||||
}
|
||||
} else {
|
||||
tphClient = telegraph.NewClient()
|
||||
}
|
||||
return tphClient
|
||||
}
|
||||
|
||||
func GetNodeImages(node telegraph.Node) []string {
|
||||
var srcs []string
|
||||
|
||||
var nodeElement telegraph.NodeElement
|
||||
data, err := json.Marshal(node)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
err = json.Unmarshal(data, &nodeElement)
|
||||
if err != nil {
|
||||
return srcs
|
||||
}
|
||||
|
||||
if nodeElement.Tag == "img" {
|
||||
if src, exists := nodeElement.Attrs["src"]; exists {
|
||||
srcs = append(srcs, src)
|
||||
}
|
||||
}
|
||||
for _, child := range nodeElement.Children {
|
||||
srcs = append(srcs, GetNodeImages(child)...)
|
||||
}
|
||||
return srcs
|
||||
}
|
||||
Reference in New Issue
Block a user