feat: add import command and batch import functionality
- Implemented the `/import` command to allow users to import files from storage to Telegram. - Added support for listing files in storage and filtering based on regex patterns. - Created a batch import task to handle multiple file uploads concurrently. - Introduced progress tracking for batch imports, providing real-time updates to users. - Enhanced storage interfaces to support file listing and reading capabilities. - Updated localization files for the new import command and its usage instructions. - Added utility functions for file size formatting and speed calculation. - Refactored Telegram storage handling to support reading from non-seekable streams.
This commit is contained in:
139
core/tasks/batchimport/execute.go
Normal file
139
core/tasks/batchimport/execute.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package batchimport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_import[%s]", t.ID))
|
||||
logger.Info("Starting batch import task")
|
||||
t.Progress.OnStart(ctx, t)
|
||||
|
||||
workers := config.C().Workers
|
||||
eg, gctx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(workers)
|
||||
|
||||
for _, elem := range t.elems {
|
||||
eg.Go(func() error {
|
||||
t.processingMu.RLock()
|
||||
if t.processing[elem.ID] != nil {
|
||||
t.processingMu.RUnlock()
|
||||
return fmt.Errorf("element with ID %s is already being processed", elem.ID)
|
||||
}
|
||||
t.processingMu.RUnlock()
|
||||
|
||||
t.processingMu.Lock()
|
||||
t.processing[elem.ID] = &elem
|
||||
t.processingMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
t.processingMu.Lock()
|
||||
delete(t.processing, elem.ID)
|
||||
t.processingMu.Unlock()
|
||||
}()
|
||||
|
||||
err := t.processElement(gctx, elem)
|
||||
if err != nil && !t.IgnoreErrors {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
t.failed[elem.ID] = err
|
||||
logger.Errorf("Failed to process file %s: %v", elem.FileInfo.Name, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
logger.Errorf("Error during batch import processing: %v", err)
|
||||
} else {
|
||||
logger.Info("Batch import task completed successfully")
|
||||
}
|
||||
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.FileInfo.Name))
|
||||
|
||||
// 检查源存储是否支持读取
|
||||
readableStorage, ok := elem.SourceStorage.(storage.StorageReadable)
|
||||
if !ok {
|
||||
return fmt.Errorf("source storage %s does not support reading", elem.SourceStorage.Name())
|
||||
}
|
||||
|
||||
logger.Info("Opening file from source storage")
|
||||
reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 构造 Telegram 存储路径: /<chat_id>/<filename>
|
||||
storagePath := fmt.Sprintf("/%d/%s", elem.TargetChatID, elem.FileInfo.Name)
|
||||
|
||||
// 注入文件大小到 context
|
||||
ctx = context.WithValue(ctx, ctxkey.ContentLength, size)
|
||||
|
||||
if config.C().Stream {
|
||||
if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil {
|
||||
return fmt.Errorf("failed to upload file to telegram: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Downloading to temporary file for ReadSeeker support")
|
||||
tempFile, err := t.downloadToTemp(reader, elem.FileInfo.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download to temp: %w", err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
defer tempFile.Close()
|
||||
|
||||
if _, err := tempFile.Seek(0, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek temp file: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("Uploading file to Telegram storage (size: %d bytes)", size)
|
||||
if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil {
|
||||
return fmt.Errorf("failed to upload file to telegram: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.uploaded.Add(size)
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
|
||||
logger.Info("File uploaded successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, error) {
|
||||
tempDir := config.C().Temp.BasePath
|
||||
if tempDir == "" {
|
||||
tempDir = os.TempDir()
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(tempDir, filepath.Base(filename)+"-*.tmp")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(tempFile, reader); err != nil {
|
||||
tempFile.Close()
|
||||
os.Remove(tempFile.Name())
|
||||
return nil, fmt.Errorf("failed to copy to temp file: %w", err)
|
||||
}
|
||||
|
||||
return tempFile, nil
|
||||
}
|
||||
225
core/tasks/batchimport/progress.go
Normal file
225
core/tasks/batchimport/progress.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package batchimport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, info TaskInfo)
|
||||
OnProgress(ctx context.Context, info TaskInfo)
|
||||
OnDone(ctx context.Context, info TaskInfo, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
MessageID int
|
||||
ChatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
func NewProgressTracker(messageID int, chatID int64) ProgressTracker {
|
||||
return &Progress{
|
||||
MessageID: messageID,
|
||||
ChatID: chatID,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
log.FromContext(ctx).Debugf("Batch import task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain("正在导入: "),
|
||||
styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
|
||||
if !shouldUpdateProgress(info.TotalSize(), info.Uploaded(), int(p.lastUpdatePercent.Load())) {
|
||||
return
|
||||
}
|
||||
percent := int((info.Uploaded() * 100) / info.TotalSize())
|
||||
if p.lastUpdatePercent.Load() == int32(percent) {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(int32(percent))
|
||||
|
||||
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Uploaded(), info.TotalSize())
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
var progressText strings.Builder
|
||||
|
||||
progressText.WriteString(fmt.Sprintf("导入进度: %d%%\n", percent))
|
||||
progressText.WriteString(fmt.Sprintf("已上传: %.2f MB / %.2f MB\n",
|
||||
float64(info.Uploaded())/(1024*1024),
|
||||
float64(info.TotalSize())/(1024*1024)))
|
||||
|
||||
if p.start.Unix() > 0 {
|
||||
elapsed := time.Since(p.start)
|
||||
speed := float64(info.Uploaded()) / elapsed.Seconds()
|
||||
progressText.WriteString(fmt.Sprintf("速度: %s/s\n", dlutil.FormatSize(int64(speed))))
|
||||
|
||||
if info.Uploaded() > 0 {
|
||||
remaining := time.Duration(float64(info.TotalSize()-info.Uploaded()) / speed * float64(time.Second))
|
||||
progressText.WriteString(fmt.Sprintf("剩余时间: %s\n", formatDuration(remaining)))
|
||||
}
|
||||
}
|
||||
|
||||
processing := info.Processing()
|
||||
if len(processing) > 0 {
|
||||
progressText.WriteString("\n正在处理:\n")
|
||||
for i, elem := range processing {
|
||||
if i >= 3 {
|
||||
progressText.WriteString(fmt.Sprintf("...和其他 %d 个文件\n", len(processing)-3))
|
||||
break
|
||||
}
|
||||
fmt.Fprintf(&progressText, "- %s\n", elem.FileName())
|
||||
}
|
||||
}
|
||||
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(progressText.String()),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(info.TaskID()),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
|
||||
log.FromContext(ctx).Debugf("Batch import task progress tracking done for message %d in chat %d", p.MessageID, p.ChatID)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
var resultText strings.Builder
|
||||
|
||||
if err != nil {
|
||||
resultText.WriteString("❌ 导入失败\n")
|
||||
fmt.Fprintf(&resultText, "错误: %v\n", err)
|
||||
} else {
|
||||
resultText.WriteString("✅ 导入完成\n")
|
||||
}
|
||||
|
||||
elapsed := time.Since(p.start)
|
||||
resultText.WriteString(fmt.Sprintf("\n总文件数: %d\n", info.Count()))
|
||||
resultText.WriteString(fmt.Sprintf("总大小: %.2f MB\n", float64(info.TotalSize())/(1024*1024)))
|
||||
resultText.WriteString(fmt.Sprintf("已上传: %.2f MB\n", float64(info.Uploaded())/(1024*1024)))
|
||||
resultText.WriteString(fmt.Sprintf("耗时: %s\n", formatDuration(elapsed)))
|
||||
|
||||
if elapsed.Seconds() > 0 {
|
||||
avgSpeed := float64(info.Uploaded()) / elapsed.Seconds()
|
||||
resultText.WriteString(fmt.Sprintf("平均速度: %s/s\n", dlutil.FormatSize(int64(avgSpeed))))
|
||||
}
|
||||
|
||||
failedFiles := info.FailedFiles()
|
||||
if len(failedFiles) > 0 {
|
||||
fmt.Fprintf(&resultText, "\n失败文件数: %d\n", len(failedFiles))
|
||||
for i, name := range failedFiles {
|
||||
if i >= 5 {
|
||||
fmt.Fprintf(&resultText, "...和其他 %d 个文件\n", len(failedFiles)-5)
|
||||
break
|
||||
}
|
||||
fmt.Fprintf(&resultText, "- %s\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(resultText.String()),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.MessageID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.ChatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUpdateProgress(total, current int64, lastPercent int) bool {
|
||||
if total == 0 {
|
||||
return false
|
||||
}
|
||||
currentPercent := int((current * 100) / total)
|
||||
return currentPercent > lastPercent && currentPercent%5 == 0
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
d = d.Round(time.Second)
|
||||
h := d / time.Hour
|
||||
d -= h * time.Hour
|
||||
m := d / time.Minute
|
||||
d -= m * time.Minute
|
||||
s := d / time.Second
|
||||
|
||||
if h > 0 {
|
||||
return fmt.Sprintf("%dh%dm%ds", h, m, s)
|
||||
}
|
||||
if m > 0 {
|
||||
return fmt.Sprintf("%dm%ds", m, s)
|
||||
}
|
||||
return fmt.Sprintf("%ds", s)
|
||||
}
|
||||
97
core/tasks/batchimport/task.go
Normal file
97
core/tasks/batchimport/task.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package batchimport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/storagetypes"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
var _ core.Executable = (*Task)(nil)
|
||||
|
||||
type TaskElement struct {
|
||||
ID string
|
||||
SourceStorage storage.Storage
|
||||
SourcePath string
|
||||
FileInfo storagetypes.FileInfo
|
||||
TargetStorage storage.Storage
|
||||
TargetChatID int64
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
ctx context.Context
|
||||
elems []TaskElement
|
||||
Progress ProgressTracker
|
||||
IgnoreErrors bool
|
||||
uploaded atomic.Int64
|
||||
totalSize int64
|
||||
processing map[string]TaskElementInfo
|
||||
processingMu sync.RWMutex
|
||||
failed map[string]error
|
||||
}
|
||||
|
||||
// Title implements core.Executable.
|
||||
func (t *Task) Title() string {
|
||||
return fmt.Sprintf("[%s](%d files/%.2fMB)", t.Type(), len(t.elems), float64(t.totalSize)/(1024*1024))
|
||||
}
|
||||
|
||||
// Type implements core.Executable.
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeBatchimport
|
||||
}
|
||||
|
||||
// TaskID implements core.Executable.
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func NewTaskElement(
|
||||
sourceStorage storage.Storage,
|
||||
fileInfo storagetypes.FileInfo,
|
||||
targetStorage storage.Storage,
|
||||
targetChatID int64,
|
||||
) *TaskElement {
|
||||
id := xid.New().String()
|
||||
return &TaskElement{
|
||||
ID: id,
|
||||
SourceStorage: sourceStorage,
|
||||
SourcePath: fileInfo.Path,
|
||||
FileInfo: fileInfo,
|
||||
TargetStorage: targetStorage,
|
||||
TargetChatID: targetChatID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewBatchImportTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
elems []TaskElement,
|
||||
progress ProgressTracker,
|
||||
ignoreErrors bool,
|
||||
) *Task {
|
||||
task := &Task{
|
||||
ID: id,
|
||||
ctx: ctx,
|
||||
elems: elems,
|
||||
Progress: progress,
|
||||
uploaded: atomic.Int64{},
|
||||
totalSize: func() int64 {
|
||||
var total int64
|
||||
for _, elem := range elems {
|
||||
total += elem.FileInfo.Size
|
||||
}
|
||||
return total
|
||||
}(),
|
||||
processing: make(map[string]TaskElementInfo),
|
||||
IgnoreErrors: ignoreErrors,
|
||||
failed: make(map[string]error),
|
||||
}
|
||||
return task
|
||||
}
|
||||
73
core/tasks/batchimport/taskinfo.go
Normal file
73
core/tasks/batchimport/taskinfo.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package batchimport
|
||||
|
||||
type TaskElementInfo interface {
|
||||
FileName() string
|
||||
FileSize() int64
|
||||
GetSourcePath() string
|
||||
SourceStorageName() string
|
||||
}
|
||||
|
||||
func (e *TaskElement) FileName() string {
|
||||
return e.FileInfo.Name
|
||||
}
|
||||
|
||||
func (e *TaskElement) FileSize() int64 {
|
||||
return e.FileInfo.Size
|
||||
}
|
||||
|
||||
func (e *TaskElement) GetSourcePath() string {
|
||||
return e.SourcePath
|
||||
}
|
||||
|
||||
func (e *TaskElement) SourceStorageName() string {
|
||||
return e.SourceStorage.Name()
|
||||
}
|
||||
|
||||
type TaskInfo interface {
|
||||
TaskID() string
|
||||
TotalSize() int64
|
||||
Uploaded() int64
|
||||
Count() int
|
||||
Processing() []TaskElementInfo
|
||||
FailedFiles() []string
|
||||
}
|
||||
|
||||
func (t *Task) TotalSize() int64 {
|
||||
return t.totalSize
|
||||
}
|
||||
|
||||
func (t *Task) Uploaded() int64 {
|
||||
return t.uploaded.Load()
|
||||
}
|
||||
|
||||
func (t *Task) Count() int {
|
||||
return len(t.elems)
|
||||
}
|
||||
|
||||
func (t *Task) Processing() []TaskElementInfo {
|
||||
t.processingMu.RLock()
|
||||
defer t.processingMu.RUnlock()
|
||||
|
||||
result := make([]TaskElementInfo, 0, len(t.processing))
|
||||
for _, elem := range t.processing {
|
||||
result = append(result, elem)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *Task) FailedFiles() []string {
|
||||
t.processingMu.RLock()
|
||||
defer t.processingMu.RUnlock()
|
||||
|
||||
result := make([]string, 0, len(t.failed))
|
||||
for id := range t.failed {
|
||||
// Find the element by ID
|
||||
for _, elem := range t.elems {
|
||||
if elem.ID == id {
|
||||
result = append(result, elem.FileInfo.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user