diff --git a/core/download.go b/core/download.go index 54385fa..960bfd1 100644 --- a/core/download.go +++ b/core/download.go @@ -2,8 +2,6 @@ package core import ( "fmt" - "io" - "os" "path/filepath" "time" @@ -50,31 +48,26 @@ func processPendingTask(task *types.Task) error { return fmt.Errorf("context is not *ext.Context: %T", task.Ctx) } - barTotalCount := calculateBarTotalCount(task.File.FileSize) - text, entities := buildProgressMessageEntity(task, barTotalCount, 0, task.StartTime, 0) + text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ Message: text, Entities: entities, ID: task.ReplyMessageID, }) - progressCallback := buildProgressCallback(ctx, task, barTotalCount) - readCloser, err := NewTelegramReader(ctx, bot.Client, &task.File.Location, - 0, task.File.FileSize-1, task.File.FileSize, - progressCallback, task.File.FileSize/100) - if err != nil { - return fmt.Errorf("创建下载失败: %w", err) - } - defer readCloser.Close() + progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize)) - dest, err := os.Create(cacheDestPath) + dest, err := NewTaskLocalFile(cacheDestPath, task.File.FileSize, progressCallback) if err != nil { return fmt.Errorf("创建文件失败: %w", err) } defer dest.Close() task.StartTime = time.Now() - if _, err := io.CopyN(dest, readCloser, task.File.FileSize); err != nil { + downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize)) + _, err = downloadBuider.Parallel(ctx, dest) + if err != nil { return fmt.Errorf("下载文件失败: %w", err) } + defer cleanCacheFile(cacheDestPath) fixTaskFileExt(task, cacheDestPath) diff --git a/core/downloader.go b/core/downloader.go new file mode 100644 index 0000000..f9e10a4 --- /dev/null +++ b/core/downloader.go @@ -0,0 +1,9 @@ +package core + +import "github.com/gotd/td/telegram/downloader" + +var Downloader *downloader.Downloader + +func init() { + Downloader = downloader.NewDownloader().WithPartSize(1024 * 1024) +} diff --git a/core/reader.go b/core/reader.go deleted file mode 100644 index 4199ed7..0000000 --- a/core/reader.go +++ /dev/null @@ -1,154 +0,0 @@ -package core - -import ( - "context" - "fmt" - "io" - "strings" - - "github.com/celestix/gotgproto" - "github.com/gotd/td/tg" - "github.com/krau/SaveAny-Bot/config" -) - -type telegramReader struct { - client *gotgproto.Client - location *tg.InputFileLocationClass - bytesread int64 - chunkSize int64 - copied int64 - contentLength int64 - start int64 - end int64 - next func() ([]byte, error) - progressCallback func(bytesRead, contentLength int64) - callbackInterval int64 - lastProgress int64 - buffer []byte - ctx context.Context -} - -func (*telegramReader) Close() error { - return nil -} - -func (r *telegramReader) Read(dst []byte) (n int, err error) { - if r.bytesread == r.contentLength { - return 0, io.EOF - } - - if r.copied >= int64(len(r.buffer)) { - r.buffer, err = r.next() - if err != nil { - return 0, err - } - if len(r.buffer) == 0 { - r.next = r.partStream() - r.buffer, err = r.next() - if err != nil { - return 0, err - } - - } - r.copied = 0 - } - n = copy(dst, r.buffer[r.copied:]) - r.copied += int64(n) - r.bytesread += int64(n) - - if r.progressCallback != nil && (r.bytesread-r.lastProgress >= r.callbackInterval || r.bytesread == r.contentLength) { - r.progressCallback(r.bytesread, r.contentLength) - r.lastProgress = r.bytesread - } - - return n, nil -} - -func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) { - var lastError error - for i := 0; i < config.Cfg.Retry; i++ { - req := &tg.UploadGetFileRequest{ - Offset: offset, - Limit: int(limit), - Location: *r.location, - } - res, err := r.client.API().UploadGetFile(r.ctx, req) - if err != nil { - if strings.Contains(err.Error(), tg.ErrTimeout) { - lastError = err - continue - } - return nil, err - } - switch result := res.(type) { - case *tg.UploadFile: - return result.Bytes, nil - default: - return nil, fmt.Errorf("unexpected type %T", r) - } - } - return nil, lastError -} - -func (r *telegramReader) partStream() func() ([]byte, error) { - - start := r.start - end := r.end - offset := start - (start % r.chunkSize) - - firstPartCut := start - offset - lastPartCut := (end % r.chunkSize) + 1 - partCount := int((end - offset + r.chunkSize) / r.chunkSize) - currentPart := 1 - - readData := func() ([]byte, error) { - if currentPart > partCount { - return make([]byte, 0), nil - } - res, err := r.chunk(offset, r.chunkSize) - if err != nil { - return nil, err - } - if len(res) == 0 { - return res, nil - } else if partCount == 1 { - res = res[firstPartCut:lastPartCut] - } else if currentPart == 1 { - res = res[firstPartCut:] - } else if currentPart == partCount { - res = res[:lastPartCut] - } - - currentPart++ - offset += r.chunkSize - return res, nil - } - return readData -} - -func NewTelegramReader( - ctx context.Context, - client *gotgproto.Client, - location *tg.InputFileLocationClass, - start int64, - end int64, - contentLength int64, - progressCallback func(bytesRead, contentLength int64), - callbackInterval int64, -) (io.ReadCloser, error) { - - r := &telegramReader{ - ctx: ctx, - location: location, - client: client, - start: start, - end: end, - chunkSize: int64(1024 * 1024), - contentLength: contentLength, - progressCallback: progressCallback, - callbackInterval: callbackInterval, - } - - r.next = r.partStream() - return r, nil -} diff --git a/core/utils.go b/core/utils.go index dc0d018..2d7f369 100644 --- a/core/utils.go +++ b/core/utils.go @@ -59,18 +59,18 @@ func processPhoto(task *types.Task, taskStorage storage.Storage, cachePath strin return saveFileWithRetry(task, taskStorage, cachePath) } -func getProgressBar(progress float64, totalCount int) string { - bar := "" - barSize := 100 / totalCount - for i := 0; i < totalCount; i++ { - if int(progress)/barSize > i { - bar += "█" - } else { - bar += "░" - } - } - return bar -} +// func getProgressBar(progress float64, updateCount int) string { +// bar := "" +// barSize := 100 / updateCount +// for i := 0; i < updateCount; i++ { +// if progress >= float64(barSize*(i+1)) { +// bar += "█" +// } else { +// bar += "░" +// } +// } +// return bar +// } func cleanCacheFile(destPath string) { if config.Cfg.Temp.CacheTTL > 0 { @@ -82,16 +82,17 @@ func cleanCacheFile(destPath string) { } } -func calculateBarTotalCount(fileSize int64) int { - barTotalCount := 5 +// 获取进度需要更新的次数 +func getProgressUpdateCount(fileSize int64) int { + updateCount := 5 if fileSize > 1024*1024*1000 { - barTotalCount = 40 + updateCount = 50 } else if fileSize > 1024*1024*500 { - barTotalCount = 20 + updateCount = 20 } else if fileSize > 1024*1024*200 { - barTotalCount = 10 + updateCount = 10 } - return barTotalCount + return updateCount } func getSpeed(bytesRead int64, startTime time.Time) string { @@ -103,13 +104,12 @@ func getSpeed(bytesRead int64, startTime time.Time) string { return fmt.Sprintf("%.2fMB/s", speed) } -func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) { +func buildProgressMessageEntity(task *types.Task, bytesRead int64, startTime time.Time, progress float64) (string, []tg.MessageEntityClass) { entityBuilder := entity.Builder{} - text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: [%s] %.2f%%", + text := fmt.Sprintf("正在处理下载任务\n文件名: %s\n保存路径: %s\n平均速度: %s\n当前进度: %.2f%%", task.FileName(), fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath), getSpeed(bytesRead, startTime), - getProgressBar(progress, barTotalCount), progress, ) var entities []tg.MessageEntityClass @@ -120,8 +120,8 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i styling.Code(fmt.Sprintf("[%s]:%s", task.StorageName, task.StoragePath)), styling.Plain("\n平均速度: "), styling.Bold(getSpeed(bytesRead, task.StartTime)), - styling.Plain("\n当前进度:\n "), - styling.Code(fmt.Sprintf("[%s] %.2f%%", getProgressBar(progress, barTotalCount), progress)), + styling.Plain("\n当前进度: "), + styling.Bold(fmt.Sprintf("%.2f%%", progress)), ); err != nil { logger.L.Errorf("Failed to build entities: %s", err) return text, entities @@ -129,14 +129,15 @@ func buildProgressMessageEntity(task *types.Task, barTotalCount int, bytesRead i return entityBuilder.Complete() } -func buildProgressCallback(ctx *ext.Context, task *types.Task, barTotalCount int) func(bytesRead, contentLength int64) { +func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int) func(bytesRead, contentLength int64) { return func(bytesRead, contentLength int64) { progress := float64(bytesRead) / float64(contentLength) * 100 logger.L.Tracef("Downloading %s: %.2f%%", task.String(), progress) - if task.File.FileSize < 1024*1024*50 || int(progress)%(100/barTotalCount) != 0 { + progressInt := int(progress) + if task.File.FileSize < 1024*1024*50 || progressInt == 0 || progressInt%int(100/updateCount) != 0 { return } - text, entities := buildProgressMessageEntity(task, barTotalCount, bytesRead, task.StartTime, progress) + text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress) ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{ Message: text, Entities: entities, @@ -156,3 +157,64 @@ func fixTaskFileExt(task *types.Task, localFilePath string) { } } } + +// TODO: configurable +func getTaskThreads(fileSize int64) int { + threads := 1 + if fileSize > 1024*1024*100 { + threads = 4 + } else if fileSize > 1024*1024*50 { + threads = 2 + } + return threads +} + +type TaskLocalFile struct { + file *os.File + size int64 + done int64 + progressCallback func(bytesRead, contentLength int64) + callbackTimes int64 + nextCallbackAt int64 + callbackInterval int64 +} + +func (t *TaskLocalFile) Read(p []byte) (n int, err error) { + return t.file.Read(p) +} + +func (t *TaskLocalFile) Close() error { + return t.file.Close() +} +func (t *TaskLocalFile) WriteAt(p []byte, off int64) (int, error) { + n, err := t.file.WriteAt(p, off) + if err != nil { + return n, err + } + t.done += int64(n) + if t.progressCallback != nil && t.done >= t.nextCallbackAt { + t.progressCallback(t.done, t.size) + t.nextCallbackAt += t.callbackInterval + } + return n, nil +} + +func NewTaskLocalFile(filePath string, fileSize int64, progressCallback func(bytesRead, contentLength int64)) (*TaskLocalFile, error) { + file, err := os.Create(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + var callbackInterval int64 + callbackInterval = fileSize / 100 + if callbackInterval == 0 { + callbackInterval = 1 + } + return &TaskLocalFile{ + file: file, + size: fileSize, + progressCallback: progressCallback, + callbackTimes: 100, + nextCallbackAt: callbackInterval, + callbackInterval: callbackInterval, + }, nil +}