feat: show progress for stream mode

This commit is contained in:
krau
2025-03-01 12:22:50 +08:00
parent 802c908384
commit 6fbb4609f9
2 changed files with 46 additions and 34 deletions

View File

@@ -8,8 +8,6 @@ import (
"github.com/celestix/gotgproto/ext"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/bot"
"github.com/krau/SaveAny-Bot/config"
@@ -56,28 +54,12 @@ func processPendingTask(task *types.Task) error {
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
// TODO: show progress for stream storage
taskStreamStorage, isStreamStorage := taskStorage.(storage.StreamStorage)
if config.Cfg.Stream {
if !isStreamStorage {
logger.L.Warnf("存储 %s 不支持流式上传", taskStorage.Name())
} else {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务 (流式)\n文件名: %s\n保存路径: %s",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
)
var entities []tg.MessageEntityClass
if err := styling.Perform(&entityBuilder,
styling.Plain("正在处理下载任务 (流式)\n文件名: "),
styling.Code(task.FileName()),
styling.Plain("\n保存路径: "),
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
); err != nil {
logger.L.Errorf("Failed to build entities: %s", err)
} else {
text, entities = entityBuilder.Complete()
}
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
@@ -89,7 +71,13 @@ func processPendingTask(task *types.Task) error {
return fmt.Errorf("创建上传流失败: %w", err)
}
defer uploadStream.Close()
_, err = downloadBuider.Stream(cancelCtx, uploadStream)
task.StartTime = time.Now()
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
progressStream := NewProgressStream(uploadStream, task.File.FileSize, progressCallback)
_, err = downloadBuider.Stream(cancelCtx, progressStream)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}
@@ -128,5 +116,4 @@ func processPendingTask(task *types.Task) error {
})
return saveFileWithRetry(cancelCtx, task, taskStorage, cacheDestPath)
}

View File

@@ -3,6 +3,7 @@ package core
import (
"context"
"fmt"
"io"
"os"
"path"
"time"
@@ -68,19 +69,6 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
return saveFileWithRetry(task.Ctx, task, taskStorage, cachePath)
}
// func getProgressBar(progress float64, updateCount int) string {
// bar := ""
// barSize := 100 / updateCount
// for i := 0; i < updateCount; i++ {
// if progress >= float64(barSize*(i+1)) {
// bar += "█"
// } else {
// bar += "░"
// }
// }
// return bar
// }
func cleanCacheFile(destPath string) {
if config.Cfg.Temp.CacheTTL > 0 {
common.RmFileAfter(destPath, time.Duration(config.Cfg.Temp.CacheTTL)*time.Second)
@@ -233,3 +221,40 @@ func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(byt
callbackInterval: callbackInterval,
}, nil
}
type ProgressStream struct {
writer io.Writer
size int64
done int64
callback func(bytesRead, contentLength int64)
nextAt int64
interval int64
}
func (ps *ProgressStream) Write(p []byte) (n int, err error) {
n, err = ps.writer.Write(p)
if err != nil {
return n, err
}
ps.done += int64(n)
if ps.callback != nil && ps.done >= ps.nextAt {
ps.callback(ps.done, ps.size)
ps.nextAt += ps.interval
}
return n, nil
}
func NewProgressStream(writer io.Writer, size int64, callback func(bytesRead, contentLength int64)) *ProgressStream {
var interval int64
interval = size / 100
if interval == 0 {
interval = 1
}
return &ProgressStream{
writer: writer,
size: size,
callback: callback,
nextAt: interval,
interval: interval,
}
}