mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-26 01:31:29 +08:00
perf: refactor file download to support multithreading
This commit is contained in:
@@ -2,8 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -50,31 +48,26 @@ func processPendingTask(task *types.Task) error {
|
||||
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
|
||||
}
|
||||
|
||||
barTotalCount := calculateBarTotalCount(task.File.FileSize)
|
||||
text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0)
|
||||
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
ID: task.ReplyMessageID,
|
||||
})
|
||||
progressCallback := buildProgressCallback(ctx, task, barTotalCount)
|
||||
readCloser, err := NewTelegramReader(ctx, bot.Client, &task.File.Location,
|
||||
0, task.File.FileSize-1, task.File.FileSize,
|
||||
progressCallback, task.File.FileSize/100)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建下载失败: %w", err)
|
||||
}
|
||||
defer readCloser.Close()
|
||||
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
|
||||
|
||||
dest, err := os.Create(cacheDestPath)
|
||||
dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件失败: %w", err)
|
||||
}
|
||||
defer dest.Close()
|
||||
task.StartTime = time.Now()
|
||||
if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil {
|
||||
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
|
||||
_, err = downloadBuider.Parallel(ctx, dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
|
||||
defer cleanCacheFile(cacheDestPath)
|
||||
|
||||
fixTaskFileExt(task, cacheDestPath)
|
||||
|
||||
9
core/downloader.go
Normal file
9
core/downloader.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package core
|
||||
|
||||
import "github.com/gotd/td/telegram/downloader"
|
||||
|
||||
var Downloader *downloader.Downloader
|
||||
|
||||
func init() {
|
||||
Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024)
|
||||
}
|
||||
154
core/reader.go
154
core/reader.go
@@ -1,154 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
type telegramReader struct {
|
||||
client *gotgproto.Client
|
||||
location *tg.InputFileLocationClass
|
||||
bytesread int64
|
||||
chunkSize int64
|
||||
copied int64
|
||||
contentLength int64
|
||||
start int64
|
||||
end int64
|
||||
next func() ([]byte, error)
|
||||
progressCallback func(bytesRead, contentLength int64)
|
||||
callbackInterval int64
|
||||
lastProgress int64
|
||||
buffer []byte
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (*telegramReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *telegramReader) Read(dst []byte) (n int, err error) {
|
||||
if r.bytesread == r.contentLength {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if r.copied >= int64(len(r.buffer)) {
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(r.buffer) == 0 {
|
||||
r.next = r.partStream()
|
||||
r.buffer, err = r.next()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
}
|
||||
r.copied = 0
|
||||
}
|
||||
n = copy(dst, r.buffer[r.copied:])
|
||||
r.copied += int64(n)
|
||||
r.bytesread += int64(n)
|
||||
|
||||
if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) {
|
||||
r.progressCallback(r.bytesread, r.contentLength)
|
||||
r.lastProgress = r.bytesread
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
|
||||
var lastError error
|
||||
for i := 0; i < config.Cfg.Retry; i++ {
|
||||
req := &tg.UploadGetFileRequest{
|
||||
Offset: offset,
|
||||
Limit: int(limit),
|
||||
Location: *r.location,
|
||||
}
|
||||
res, err := r.client.API().UploadGetFile(r.ctx, req)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), tg.ErrTimeout) {
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
switch result := res.(type) {
|
||||
case *tg.UploadFile:
|
||||
return result.Bytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected type %T", r)
|
||||
}
|
||||
}
|
||||
return nil, lastError
|
||||
}
|
||||
|
||||
func (r *telegramReader) partStream() func() ([]byte, error) {
|
||||
|
||||
start := r.start
|
||||
end := r.end
|
||||
offset := start - (start % r.chunkSize)
|
||||
|
||||
firstPartCut := start - offset
|
||||
lastPartCut := (end % r.chunkSize) + 1
|
||||
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
|
||||
currentPart := 1
|
||||
|
||||
readData := func() ([]byte, error) {
|
||||
if currentPart > partCount {
|
||||
return make([]byte, 0), nil
|
||||
}
|
||||
res, err := r.chunk(offset, r.chunkSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return res, nil
|
||||
} else if partCount == 1 {
|
||||
res = res[firstPartCut:lastPartCut]
|
||||
} else if currentPart == 1 {
|
||||
res = res[firstPartCut:]
|
||||
} else if currentPart == partCount {
|
||||
res = res[:lastPartCut]
|
||||
}
|
||||
|
||||
currentPart++
|
||||
offset += r.chunkSize
|
||||
return res, nil
|
||||
}
|
||||
return readData
|
||||
}
|
||||
|
||||
func NewTelegramReader(
|
||||
ctx context.Context,
|
||||
client *gotgproto.Client,
|
||||
location *tg.InputFileLocationClass,
|
||||
start int64,
|
||||
end int64,
|
||||
contentLength int64,
|
||||
progressCallback func(bytesRead, contentLength int64),
|
||||
callbackInterval int64,
|
||||
) (io.ReadCloser, error) {
|
||||
|
||||
r := &telegramReader{
|
||||
ctx: ctx,
|
||||
location: location,
|
||||
client: client,
|
||||
start: start,
|
||||
end: end,
|
||||
chunkSize: int64(1024 * 1024),
|
||||
contentLength: contentLength,
|
||||
progressCallback: progressCallback,
|
||||
callbackInterval: callbackInterval,
|
||||
}
|
||||
|
||||
r.next = r.partStream()
|
||||
return r, nil
|
||||
}
|
||||
114
core/utils.go
114
core/utils.go
@@ -59,18 +59,18 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin
|
||||
return saveFileWithRetry(task, taskStorage, cachePath)
|
||||
}
|
||||
|
||||
func getProgressBar(progress float64, totalCount int) string {
|
||||
bar := ""
|
||||
barSize := 100 / totalCount
|
||||
for i := 0; i < totalCount; i++ {
|
||||
if int(progress)/barSize > i {
|
||||
bar += "█"
|
||||
} else {
|
||||
bar += "░"
|
||||
}
|
||||
}
|
||||
return bar
|
||||
}
|
||||
// func getProgressBar(progress float64, updateCount int) string {
|
||||
// bar := ""
|
||||
// barSize := 100 / updateCount
|
||||
// for i := 0; i < updateCount; i++ {
|
||||
// if progress >= float64(barSize*(i+1)) {
|
||||
// bar += "█"
|
||||
// } else {
|
||||
// bar += "░"
|
||||
// }
|
||||
// }
|
||||
// return bar
|
||||
// }
|
||||
|
||||
func cleanCacheFile(destPath string) {
|
||||
if config.Cfg.Temp.CacheTTL > 0 {
|
||||
@@ -82,16 +82,17 @@ func cleanCacheFile(destPath string) {
|
||||
}
|
||||
}
|
||||
|
||||
func calculateBarTotalCount(fileSize int64) int {
|
||||
barTotalCount := 5
|
||||
// 获取进度需要更新的次数
|
||||
func getProgressUpdateCount(fileSize int64) int {
|
||||
updateCount := 5
|
||||
if fileSize > 1024*1024*1000 {
|
||||
barTotalCount = 40
|
||||
updateCount = 50
|
||||
} else if fileSize > 1024*1024*500 {
|
||||
barTotalCount = 20
|
||||
updateCount = 20
|
||||
} else if fileSize > 1024*1024*200 {
|
||||
barTotalCount = 10
|
||||
updateCount = 10
|
||||
}
|
||||
return barTotalCount
|
||||
return updateCount
|
||||
}
|
||||
|
||||
func getSpeed(bytesRead int64, startTime time.Time) string {
|
||||
@@ -103,13 +104,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string {
|
||||
return fmt.Sprintf("%.2fMB/s", speed)
|
||||
}
|
||||
|
||||
func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
|
||||
func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) {
|
||||
entityBuilder := entity.Builder{}
|
||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%",
|
||||
text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%",
|
||||
task.FileName(),
|
||||
fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath),
|
||||
getSpeed(bytesRead, startTime),
|
||||
getProgressBar(progress, barTotalCount),
|
||||
progress,
|
||||
)
|
||||
var entities []tg.MessageEntityClass
|
||||
@@ -120,8 +120,8 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)),
|
||||
styling.Plain("\n平均速度: "),
|
||||
styling.Bold(getSpeed(bytesRead, task.StartTime)),
|
||||
styling.Plain("\n当前进度:\n "),
|
||||
styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)),
|
||||
styling.Plain("\n当前进度: "),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", progress)),
|
||||
); err != nil {
|
||||
logger.L.Errorf("Failed to build entities: %s", err)
|
||||
return text, entities
|
||||
@@ -129,14 +129,15 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i
|
||||
return entityBuilder.Complete()
|
||||
}
|
||||
|
||||
func buildProgressCallback(ctx *ext.Context, task *types.Task, barTotalCount int) func(bytesRead, contentLength int64) {
|
||||
func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) {
|
||||
return func(bytesRead, contentLength int64) {
|
||||
progress := float64(bytesRead) / float64(contentLength) * 100
|
||||
logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress)
|
||||
if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 {
|
||||
progressInt := int(progress)
|
||||
if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 {
|
||||
return
|
||||
}
|
||||
text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress)
|
||||
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
|
||||
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
|
||||
Message: text,
|
||||
Entities: entities,
|
||||
@@ -156,3 +157,64 @@ func fixTaskFileExt(task *types.Task, localFilePath string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: configurable
|
||||
func getTaskThreads(fileSize int64) int {
|
||||
threads := 1
|
||||
if fileSize > 1024*1024*100 {
|
||||
threads = 4
|
||||
} else if fileSize > 1024*1024*50 {
|
||||
threads = 2
|
||||
}
|
||||
return threads
|
||||
}
|
||||
|
||||
type TaskLocalFile struct {
|
||||
file *os.File
|
||||
size int64
|
||||
done int64
|
||||
progressCallback func(bytesRead, contentLength int64)
|
||||
callbackTimes int64
|
||||
nextCallbackAt int64
|
||||
callbackInterval int64
|
||||
}
|
||||
|
||||
func (t *TaskLocalFile) Read(p []byte) (n int, err error) {
|
||||
return t.file.Read(p)
|
||||
}
|
||||
|
||||
func (t *TaskLocalFile) Close() error {
|
||||
return t.file.Close()
|
||||
}
|
||||
func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) {
|
||||
n, err := t.file.WriteAt(p, off)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
t.done += int64(n)
|
||||
if t.progressCallback != nil && t.done >= t.nextCallbackAt {
|
||||
t.progressCallback(t.done, t.size)
|
||||
t.nextCallbackAt += t.callbackInterval
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
var callbackInterval int64
|
||||
callbackInterval = fileSize / 100
|
||||
if callbackInterval == 0 {
|
||||
callbackInterval = 1
|
||||
}
|
||||
return &TaskLocalFile{
|
||||
file: file,
|
||||
size: fileSize,
|
||||
progressCallback: progressCallback,
|
||||
callbackTimes: 100,
|
||||
nextCallbackAt: callbackInterval,
|
||||
callbackInterval: callbackInterval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user