perf: refactor file download to support multithreading

This commit is contained in:
krau
2025-02-21 13:49:15 +08:00
parent 8975589c43
commit ed21b65c98
4 changed files with 104 additions and 194 deletions

View File

@@ -2,8 +2,6 @@ package core
import (
"fmt"
"io"
"os"
"path/filepath"
"time"
@@ -50,31 +48,26 @@ func processPendingTask(task *types.Task) error {
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
}
barTotalCount := calculateBarTotalCount(task.File.FileSize)
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
ID: task.ReplyMessageID,
})
progressCallback := buildProgressCallback(ctx, task, barTotalCount)
readCloser, err := NewTelegramReader(ctx, bot.Client, &task.File.Location,
0, task.File.FileSize-1, task.File.FileSize,
progressCallback, task.File.FileSize/100)
if err != nil {
return fmt.Errorf("创建下载失败: %w", err)
}
defer readCloser.Close()
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
dest, err := os.Create(cacheDestPath)
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
if err != nil {
return fmt.Errorf("创建文件失败: %w", err)
}
defer dest.Close()
task.StartTime = time.Now()
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
_, err = downloadBuider.Parallel(ctx, dest)
if err != nil {
return fmt.Errorf("下载文件失败: %w", err)
}
defer cleanCacheFile(cacheDestPath)
fixTaskFileExt(task, cacheDestPath)

9
core/downloader.go Normal file
View File

@@ -0,0 +1,9 @@
package core
import "github.com/gotd/td/telegram/downloader"
var Downloader *downloader.Downloader
func init() {
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
}

View File

@@ -1,154 +0,0 @@
package core
import (
"context"
"fmt"
"io"
"strings"
"github.com/celestix/gotgproto"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/config"
)
type telegramReader struct {
client *gotgproto.Client
location *tg.InputFileLocationClass
bytesread int64
chunkSize int64
copied int64
contentLength int64
start int64
end int64
next func() ([]byte, error)
progressCallback func(bytesRead, contentLength int64)
callbackInterval int64
lastProgress int64
buffer []byte
ctx context.Context
}
func (*telegramReader) Close() error {
return nil
}
func (r *telegramReader) Read(dst []byte) (n int, err error) {
if r.bytesread == r.contentLength {
return 0, io.EOF
}
if r.copied >= int64(len(r.buffer)) {
r.buffer, err = r.next()
if err != nil {
return 0, err
}
if len(r.buffer) == 0 {
r.next = r.partStream()
r.buffer, err = r.next()
if err != nil {
return 0, err
}
}
r.copied = 0
}
n = copy(dst, r.buffer[r.copied:])
r.copied += int64(n)
r.bytesread += int64(n)
if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) {
r.progressCallback(r.bytesread, r.contentLength)
r.lastProgress = r.bytesread
}
return n, nil
}
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
var lastError error
for i := 0; i < config.Cfg.Retry; i++ {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: *r.location,
}
res, err := r.client.API().UploadGetFile(r.ctx, req)
if err != nil {
if strings.Contains(err.Error(), tg.ErrTimeout) {
lastError = err
continue
}
return nil, err
}
switch result := res.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
return nil, lastError
}
func (r *telegramReader) partStream() func() ([]byte, error) {
start := r.start
end := r.end
offset := start - (start % r.chunkSize)
firstPartCut := start - offset
lastPartCut := (end % r.chunkSize) + 1
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
currentPart := 1
readData := func() ([]byte, error) {
if currentPart > partCount {
return make([]byte, 0), nil
}
res, err := r.chunk(offset, r.chunkSize)
if err != nil {
return nil, err
}
if len(res) == 0 {
return res, nil
} else if partCount == 1 {
res = res[firstPartCut:lastPartCut]
} else if currentPart == 1 {
res = res[firstPartCut:]
} else if currentPart == partCount {
res = res[:lastPartCut]
}
currentPart++
offset += r.chunkSize
return res, nil
}
return readData
}
func NewTelegramReader(
ctx context.Context,
client *gotgproto.Client,
location *tg.InputFileLocationClass,
start int64,
end int64,
contentLength int64,
progressCallback func(bytesRead, contentLength int64),
callbackInterval int64,
) (io.ReadCloser, error) {
r := &telegramReader{
ctx: ctx,
location: location,
client: client,
start: start,
end: end,
chunkSize: int64(1024 * 1024),
contentLength: contentLength,
progressCallback: progressCallback,
callbackInterval: callbackInterval,
}
r.next = r.partStream()
return r, nil
}

View File

@@ -59,18 +59,18 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
return saveFileWithRetry(task, taskStorage, cachePath)
}
func getProgressBar(progress float64, totalCount int) string {
bar := ""
barSize := 100 / totalCount
for i := 0; i < totalCount; i++ {
if int(progress)/barSize > i {
bar += "█"
} else {
bar += "░"
}
}
return bar
}
// func getProgressBar(progress float64, updateCount int) string {
// bar := ""
// barSize := 100 / updateCount
// for i := 0; i < updateCount; i++ {
// if progress >= float64(barSize*(i+1)) {
// bar += "█"
// } else {
// bar += "░"
// }
// }
// return bar
// }
func cleanCacheFile(destPath string) {
if config.Cfg.Temp.CacheTTL > 0 {
@@ -82,16 +82,17 @@ func cleanCacheFile(destPath string) {
}
}
func calculateBarTotalCount(fileSize int64) int {
barTotalCount := 5
// 获取进度需要更新的次数
func getProgressUpdateCount(fileSize int64) int {
updateCount := 5
if fileSize > 1024*1024*1000 {
barTotalCount = 40
updateCount = 50
} else if fileSize > 1024*1024*500 {
barTotalCount = 20
updateCount = 20
} else if fileSize > 1024*1024*200 {
barTotalCount = 10
updateCount = 10
}
return barTotalCount
return updateCount
}
func getSpeed(bytesRead int64, startTime time.Time) string {
@@ -103,13 +104,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
return fmt.Sprintf("%.2fMB/s", speed)
}
func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
entityBuilder := entity.Builder{}
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
task.FileName(),
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
getSpeed(bytesRead, startTime),
getProgressBar(progress, barTotalCount),
progress,
)
var entities []tg.MessageEntityClass
@@ -120,8 +120,8 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
styling.Plain("\n平均速度: "),
styling.Bold(getSpeed(bytesRead, task.StartTime)),
styling.Plain("\n当前进度:\n "),
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
styling.Plain("\n当前进度: "),
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
); err != nil {
logger.L.Errorf("Failed to build entities: %s", err)
return text, entities
@@ -129,14 +129,15 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
return entityBuilder.Complete()
}
func buildProgressCallback(ctx *ext.Context, task *types.Task, barTotalCount int) func(bytesRead, contentLength int64) {
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
return func(bytesRead, contentLength int64) {
progress := float64(bytesRead) / float64(contentLength) * 100
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
progressInt := int(progress)
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
return
}
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
Message: text,
Entities: entities,
@@ -156,3 +157,64 @@ func fixTaskFileExt(task *types.Task, localFilePath string) {
}
}
}
// TODO: configurable
func getTaskThreads(fileSize int64) int {
threads := 1
if fileSize > 1024*1024*100 {
threads = 4
} else if fileSize > 1024*1024*50 {
threads = 2
}
return threads
}
type TaskLocalFile struct {
file *os.File
size int64
done int64
progressCallback func(bytesRead, contentLength int64)
callbackTimes int64
nextCallbackAt int64
callbackInterval int64
}
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
return t.file.Read(p)
}
func (t *TaskLocalFile) Close() error {
return t.file.Close()
}
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
n, err := t.file.WriteAt(p, off)
if err != nil {
return n, err
}
t.done += int64(n)
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
t.progressCallback(t.done, t.size)
t.nextCallbackAt += t.callbackInterval
}
return n, nil
}
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
file, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
var callbackInterval int64
callbackInterval = fileSize / 100
if callbackInterval == 0 {
callbackInterval = 1
}
return &TaskLocalFile{
file: file,
size: fileSize,
progressCallback: progressCallback,
callbackTimes: 100,
nextCallbackAt: callbackInterval,
callbackInterval: callbackInterval,
}, nil
}