Files
SaveAny-Bot/core/utils.go
2025-04-12 16:27:23 +08:00

299 lines
8.1 KiB
Go

package core
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"path"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/celestix/telegraph-go/v2"
"github.com/gabriel-vasile/mimetype"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/common"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/storage"
"github.com/krau/SaveAny-Bot/types"
)
func saveFileWithRetry(ctx context.Context, storagePath string, taskStorage storage.Storage, cacheFilePath string) error {
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
fileStat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, types.ContextKeyContentLength, fileStat.Size())
for i := 0; i <= config.Cfg.Retry; i++ {
if err := vctx.Err(); err != nil {
return fmt.Errorf("context canceled while saving file: %w", err)
}
file, err := os.Open(cacheFilePath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
if err := taskStorage.Save(vctx, file, storagePath); err != nil {
if i == config.Cfg.Retry {
return fmt.Errorf("failed to save file: %w", err)
}
common.Log.Errorf("Failed to save file: %s, retrying...", err)
select {
case <-vctx.Done():
return fmt.Errorf("context canceled during retry delay: %w", vctx.Err())
case <-time.After(time.Duration(i*500) * time.Millisecond):
}
continue
}
return nil
}
return nil
}
func processPhoto(task *types.Task, taskStorage storage.Storage) error {
res, err := bot.Client.API().UploadGetFile(task.Ctx, &tg.UploadGetFileRequest{
Location: task.File.Location,
Offset: 0,
Limit: 1024 * 1024,
})
if err != nil {
return fmt.Errorf("failed to get file: %w", err)
}
result, ok := res.(*tg.UploadFile)
if !ok {
return fmt.Errorf("unexpected type %T", res)
}
common.Log.Infof("Downloaded photo: %s", task.FileName())
return taskStorage.Save(task.Ctx, bytes.NewReader(result.Bytes), task.StoragePath)
}
func cleanCacheFile(destPath string) {
if config.Cfg.Temp.CacheTTL > 0 {
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
} else {
if err := os.Remove(destPath); err != nil {
common.Log.Errorf("Failed to purge file: %s", err)
}
}
}
// 获取进度需要更新的次数
func getProgressUpdateCount(fileSize int64) int {
updateCount := 5
if fileSize > 1024*1024*1000 {
updateCount = 50
} else if fileSize > 1024*1024*500 {
updateCount = 20
} else if fileSize > 1024*1024*200 {
updateCount = 10
}
return updateCount
}
func getSpeed(bytesRead int64, startTime time.Time) string {
if startTime.IsZero() {
return "0MB/s"
}
elapsed := time.Since(startTime)
speed := float64(bytesRead) / 1024 / 1024 / elapsed.Seconds()
return fmt.Sprintf("%.2fMB/s", speed)
}
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
getSpeed(bytesRead, startTime),
progress,
)
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("正在处理下载任务\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度: "),
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
); err != nil {
common.Log.Errorf("Failed to build entities: %s", err)
return text, entities
}
return entityBuilder.Complete()
}
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
return func(bytesRead, contentLength int64) {
progress := float64(bytesRead) / float64(contentLength) * 100
common.Log.Tracef("Downloading %s: %.2f%%", task.String(), progress)
progressInt := int(progress)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return
}
if task.ReplyMessageID == 0 {
return
}
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
ReplyMarkup: getCancelTaskMarkup(task),
})
}
}
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
return &tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
}
}
func fixTaskFileExt(task *types.Task, localFilePath string) {
if path.Ext(task.FileName()) == "" {
mimeType, err := mimetype.DetectFile(localFilePath)
if err != nil {
common.Log.Errorf("Failed to detect mime type: %s", err)
} else {
task.File.FileName = fmt.Sprintf("%s%s", task.FileName(), mimeType.Extension())
task.StoragePath = fmt.Sprintf("%s%s", task.StoragePath, mimeType.Extension())
}
}
}
func getTaskThreads(fileSize int64) int {
threads := 1
if fileSize > 1024*1024*100 {
threads = config.Cfg.Threads
} else if fileSize > 1024*1024*50 {
threads = config.Cfg.Threads / 2
}
return threads
}
type TaskLocalFile struct {
file *os.File
size int64
done int64
progressCallback func(bytesRead, contentLength int64)
callbackTimes int64
nextCallbackAt int64
callbackInterval int64
}
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
return t.file.Read(p)
}
func (t *TaskLocalFile) Close() error {
return t.file.Close()
}
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
n, err := t.file.WriteAt(p, off)
if err != nil {
return n, err
}
t.done += int64(n)
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
t.progressCallback(t.done, t.size)
t.nextCallbackAt += t.callbackInterval
}
return n, nil
}
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
file, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
var callbackInterval int64
callbackInterval = fileSize / 100
if callbackInterval == 0 {
callbackInterval = 1
}
return &TaskLocalFile{
file: file,
size: fileSize,
progressCallback: progressCallback,
callbackTimes: 100,
nextCallbackAt: callbackInterval,
callbackInterval: callbackInterval,
}, nil
}
type ProgressStream struct {
writer io.Writer
size int64
done int64
callback func(bytesRead, contentLength int64)
nextAt int64
interval int64
}
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
n, err = ps.writer.Write(p)
if err != nil {
return n, err
}
ps.done += int64(n)
if ps.callback != nil && ps.done >= ps.nextAt {
ps.callback(ps.done, ps.size)
ps.nextAt += ps.interval
}
return n, nil
}
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
var interval int64
interval = size / 100
if interval == 0 {
interval = 1
}
return &ProgressStream{
writer: writer,
size: size,
callback: callback,
nextAt: interval,
interval: interval,
}
}
func getNodeImages(node telegraph.Node) []string {
var srcs []string
var nodeElement telegraph.NodeElement
data, err := json.Marshal(node)
if err != nil {
return srcs
}
err = json.Unmarshal(data, &nodeElement)
if err != nil {
return srcs
}
if nodeElement.Tag == "img" {
if src, exists := nodeElement.Attrs["src"]; exists {
srcs = append(srcs, src)
}
}
for _, child := range nodeElement.Children {
srcs = append(srcs, getNodeImages(child)...)
}
return srcs
}