feat: add import command and batch import functionality

- Implemented the `/import` command to allow users to import files from storage to Telegram.
- Added support for listing files in storage and filtering based on regex patterns.
- Created a batch import task to handle multiple file uploads concurrently.
- Introduced progress tracking for batch imports, providing real-time updates to users.
- Enhanced storage interfaces to support file listing and reading capabilities.
- Updated localization files for the new import command and its usage instructions.
- Added utility functions for file size formatting and speed calculation.
- Refactored Telegram storage handling to support reading from non-seekable streams.
This commit is contained in:
krau
2026-01-17 18:59:09 +08:00
parent 3ce00884a0
commit eda0756f0c
18 changed files with 902 additions and 69 deletions

View File

@@ -0,0 +1,180 @@
package handlers
import (
"fmt"
"regexp"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
storconfig "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/batchimport"
"github.com/krau/SaveAny-Bot/pkg/storagetypes"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
func handleImportCmd(ctx *ext.Context, update *ext.Update) error {
logger := log.FromContext(ctx)
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 3 {
ctx.Reply(update, ext.ReplyTextString("用法: /import <storage_name> <dir_path> [target_chat_id] [filter]\n\n"+
"示例:\n"+
"/import 本机1 /downloads\n"+
"/import MyAlist /media/photos -1001234567890\n"+
"/import MyLocal /backup \".*\\\\.mp4$\""), nil)
return dispatcher.EndGroups
}
storageName := args[1]
dirPath := args[2]
userID := update.GetUserChat().GetID()
// 1. 获取源存储端
stor, err := storage.GetStorageByUserIDAndName(ctx, userID, storageName)
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("存储端 '%s' 不存在或您无权访问: %v", storageName, err)), nil)
return dispatcher.EndGroups
}
// 2. 检查是否支持列举
listable, ok := stor.(storage.StorageListable)
if !ok {
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("存储端 '%s' 不支持列举文件功能", storageName)), nil)
return dispatcher.EndGroups
}
// 3. 检查是否支持读取
_, ok = stor.(storage.StorageReadable)
if !ok {
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("存储端 '%s' 不支持读取文件功能", storageName)), nil)
return dispatcher.EndGroups
}
// 4. 获取目标 Telegram 存储
telegramStorage, err := storage.GetTelegramStorageByUserID(ctx, userID)
if err != nil {
ctx.Reply(update, ext.ReplyTextString(fmt.Sprintf("未找到可用的 Telegram 存储: %v", err)), nil)
return dispatcher.EndGroups
}
// 5. 列举目录文件
replied, err := ctx.Reply(update, ext.ReplyTextString("正在获取文件列表..."), nil)
if err != nil {
logger.Errorf("Failed to reply: %s", err)
return dispatcher.EndGroups
}
files, err := listable.ListFiles(ctx, dirPath)
if err != nil {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: fmt.Sprintf("获取文件列表失败: %v", err),
})
return dispatcher.EndGroups
}
// 6. 过滤文件
var filter *regexp.Regexp
if len(args) >= 5 {
filter, err = regexp.Compile(args[4])
if err != nil {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: fmt.Sprintf("正则表达式无效: %v", err),
})
return dispatcher.EndGroups
}
}
filteredFiles := make([]storagetypes.FileInfo, 0)
for _, file := range files {
if file.IsDir {
continue
}
if filter != nil && !filter.MatchString(file.Name) {
continue
}
filteredFiles = append(filteredFiles, file)
}
if len(filteredFiles) == 0 {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: "目录中没有可导入的文件",
})
return dispatcher.EndGroups
}
// 7. 解析目标 Chat ID
// Get default chat_id from Telegram storage config
targetChatID := int64(0)
if telegramCfg := config.C().GetStorageByName(telegramStorage.Name()); telegramCfg != nil {
if tgCfg, ok := telegramCfg.(*storconfig.TelegramStorageConfig); ok {
targetChatID = tgCfg.ChatID
}
}
if len(args) >= 4 {
parsedChatID, err := tgutil.ParseChatID(ctx, args[3])
if err != nil {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: fmt.Sprintf("无效的 Chat ID: %v", err),
})
return dispatcher.EndGroups
}
targetChatID = parsedChatID
}
if targetChatID == 0 {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: "未指定目标频道 ID且 Telegram 存储未配置默认 chat_id",
})
return dispatcher.EndGroups
}
// 8. 创建任务元素
elems := make([]batchimport.TaskElement, 0, len(filteredFiles))
var totalSize int64
for _, file := range filteredFiles {
elem := batchimport.NewTaskElement(stor, file, telegramStorage, targetChatID)
elems = append(elems, *elem)
totalSize += file.Size
}
// 9. 创建并添加任务
taskID := xid.New().String()
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
task := batchimport.NewBatchImportTask(
taskID,
injectCtx,
elems,
batchimport.NewProgressTracker(replied.ID, userID),
true, // IgnoreErrors
)
if err := core.AddTask(injectCtx, task); err != nil {
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: fmt.Sprintf("添加任务失败: %v", err),
})
return dispatcher.EndGroups
}
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
ID: replied.ID,
Message: fmt.Sprintf("✅ 已添加 %d 个文件到导入队列\n总大小: %.2f MB\n任务 ID: %s", len(elems), float64(totalSize)/(1024*1024), taskID),
})
return dispatcher.EndGroups
}

View File

@@ -31,6 +31,7 @@ var CommandHandlers = []DescCommandHandler{
{"dl", i18nk.BotMsgCmdDl, handleDlCmd},
{"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd},
{"ytdlp", i18nk.BotMsgCmdYtdlp, handleYtdlpCmd},
{"import", i18nk.BotMsgCmdImport, handleImportCmd},
{"task", i18nk.BotMsgCmdTask, handleTaskCmd},
{"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd},
{"config", i18nk.BotMsgCmdConfig, handleConfigCmd},

View File

@@ -21,6 +21,7 @@ const (
BotMsgCmdDl Key = "bot.msg.cmd.dl"
BotMsgCmdFnametmpl Key = "bot.msg.cmd.fnametmpl"
BotMsgCmdHelp Key = "bot.msg.cmd.help"
BotMsgCmdImport Key = "bot.msg.cmd.import"
BotMsgCmdLswatch Key = "bot.msg.cmd.lswatch"
BotMsgCmdParser Key = "bot.msg.cmd.parser"
BotMsgCmdRule Key = "bot.msg.cmd.rule"

View File

@@ -29,6 +29,7 @@ bot:
/silent - Toggle silent mode
/storage - Set default storage
/save [custom filename] - Save file
/import <storage_name> <dir_path> [channel_id] [filter] - Import files from storage to Telegram
/dir - Manage storage directories
/rule - Manage rules
/config - Modify configuration
@@ -52,6 +53,7 @@ bot:
dl: "Download files from given links"
aria2dl: "Download files using Aria2"
ytdlp: "Download video/audio using yt-dlp"
import: "Import files from storage to Telegram"
task: "Manage task queue"
cancel: "Cancel task"
watch: "Watch chats (UserBot)"

View File

@@ -30,6 +30,7 @@ bot:
/storage - 设置默认存储位置
/save [自定义文件名] - 保存文件
/dl <链接1> <链接2> ... - 下载给定链接的文件
/import <存储名> <目录路径> [频道ID] [过滤器] - 从存储端导入文件到 Telegram
/dir - 管理存储目录
/rule - 管理规则
/config - 修改配置
@@ -53,6 +54,7 @@ bot:
dl: "下载给定链接的文件"
aria2dl: "使用 Aria2 下载给定链接的文件"
ytdlp: "使用 yt-dlp 下载视频/音频"
import: "从存储端导入文件到 Telegram"
task: "管理任务队列"
cancel: "取消任务"
watch: "监听聊天(UserBot)"

View File

@@ -1,6 +1,9 @@
package dlutil
import "time"
import (
"fmt"
"time"
)
var threadsLevels = []struct {
threads int
@@ -31,3 +34,23 @@ func GetSpeed(downloaded int64, startTime time.Time) float64 {
}
return float64(downloaded) / elapsed
}
// FormatSize formats a byte size as a human-readable string
func FormatSize(bytes int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)
switch {
case bytes >= GB:
return fmt.Sprintf("%.2f GB", float64(bytes)/float64(GB))
case bytes >= MB:
return fmt.Sprintf("%.2f MB", float64(bytes)/float64(MB))
case bytes >= KB:
return fmt.Sprintf("%.2f KB", float64(bytes)/float64(KB))
default:
return fmt.Sprintf("%d B", bytes)
}
}

View File

@@ -0,0 +1,139 @@
package batchimport
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/storage"
"golang.org/x/sync/errgroup"
)
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_import[%s]", t.ID))
logger.Info("Starting batch import task")
t.Progress.OnStart(ctx, t)
workers := config.C().Workers
eg, gctx := errgroup.WithContext(ctx)
eg.SetLimit(workers)
for _, elem := range t.elems {
eg.Go(func() error {
t.processingMu.RLock()
if t.processing[elem.ID] != nil {
t.processingMu.RUnlock()
return fmt.Errorf("element with ID %s is already being processed", elem.ID)
}
t.processingMu.RUnlock()
t.processingMu.Lock()
t.processing[elem.ID] = &elem
t.processingMu.Unlock()
defer func() {
t.processingMu.Lock()
delete(t.processing, elem.ID)
t.processingMu.Unlock()
}()
err := t.processElement(gctx, elem)
if err != nil && !t.IgnoreErrors {
return err
}
if err != nil {
t.failed[elem.ID] = err
logger.Errorf("Failed to process file %s: %v", elem.FileInfo.Name, err)
}
return nil
})
}
err := eg.Wait()
if err != nil {
logger.Errorf("Error during batch import processing: %v", err)
} else {
logger.Info("Batch import task completed successfully")
}
t.Progress.OnDone(ctx, t, err)
return err
}
func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.FileInfo.Name))
// 检查源存储是否支持读取
readableStorage, ok := elem.SourceStorage.(storage.StorageReadable)
if !ok {
return fmt.Errorf("source storage %s does not support reading", elem.SourceStorage.Name())
}
logger.Info("Opening file from source storage")
reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
defer reader.Close()
// 构造 Telegram 存储路径: /<chat_id>/<filename>
storagePath := fmt.Sprintf("/%d/%s", elem.TargetChatID, elem.FileInfo.Name)
// 注入文件大小到 context
ctx = context.WithValue(ctx, ctxkey.ContentLength, size)
if config.C().Stream {
if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil {
return fmt.Errorf("failed to upload file to telegram: %w", err)
}
} else {
logger.Info("Downloading to temporary file for ReadSeeker support")
tempFile, err := t.downloadToTemp(reader, elem.FileInfo.Name)
if err != nil {
return fmt.Errorf("failed to download to temp: %w", err)
}
defer os.Remove(tempFile.Name())
defer tempFile.Close()
if _, err := tempFile.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek temp file: %w", err)
}
logger.Infof("Uploading file to Telegram storage (size: %d bytes)", size)
if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil {
return fmt.Errorf("failed to upload file to telegram: %w", err)
}
}
t.uploaded.Add(size)
t.Progress.OnProgress(ctx, t)
logger.Info("File uploaded successfully")
return nil
}
func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, error) {
tempDir := config.C().Temp.BasePath
if tempDir == "" {
tempDir = os.TempDir()
}
tempFile, err := os.CreateTemp(tempDir, filepath.Base(filename)+"-*.tmp")
if err != nil {
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
if _, err := io.Copy(tempFile, reader); err != nil {
tempFile.Close()
os.Remove(tempFile.Name())
return nil, fmt.Errorf("failed to copy to temp file: %w", err)
}
return tempFile, nil
}

View File

@@ -0,0 +1,225 @@
package batchimport
import (
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
type ProgressTracker interface {
OnStart(ctx context.Context, info TaskInfo)
OnProgress(ctx context.Context, info TaskInfo)
OnDone(ctx context.Context, info TaskInfo, err error)
}
type Progress struct {
MessageID int
ChatID int64
start time.Time
lastUpdatePercent atomic.Int32
}
func NewProgressTracker(messageID int, chatID int64) ProgressTracker {
return &Progress{
MessageID: messageID,
ChatID: chatID,
}
}
func (p *Progress) OnStart(ctx context.Context, info TaskInfo) {
p.start = time.Now()
p.lastUpdatePercent.Store(0)
log.FromContext(ctx).Debugf("Batch import task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID)
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain("正在导入: "),
styling.Code(fmt.Sprintf("%.2f MB (%d个文件)", float64(info.TotalSize())/(1024*1024), info.Count())),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(info.TaskID()),
},
},
},
})
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
}
}
func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) {
if !shouldUpdateProgress(info.TotalSize(), info.Uploaded(), int(p.lastUpdatePercent.Load())) {
return
}
percent := int((info.Uploaded() * 100) / info.TotalSize())
if p.lastUpdatePercent.Load() == int32(percent) {
return
}
p.lastUpdatePercent.Store(int32(percent))
log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Uploaded(), info.TotalSize())
entityBuilder := entity.Builder{}
var progressText strings.Builder
progressText.WriteString(fmt.Sprintf("导入进度: %d%%\n", percent))
progressText.WriteString(fmt.Sprintf("已上传: %.2f MB / %.2f MB\n",
float64(info.Uploaded())/(1024*1024),
float64(info.TotalSize())/(1024*1024)))
if p.start.Unix() > 0 {
elapsed := time.Since(p.start)
speed := float64(info.Uploaded()) / elapsed.Seconds()
progressText.WriteString(fmt.Sprintf("速度: %s/s\n", dlutil.FormatSize(int64(speed))))
if info.Uploaded() > 0 {
remaining := time.Duration(float64(info.TotalSize()-info.Uploaded()) / speed * float64(time.Second))
progressText.WriteString(fmt.Sprintf("剩余时间: %s\n", formatDuration(remaining)))
}
}
processing := info.Processing()
if len(processing) > 0 {
progressText.WriteString("\n正在处理:\n")
for i, elem := range processing {
if i >= 3 {
progressText.WriteString(fmt.Sprintf("...和其他 %d 个文件\n", len(processing)-3))
break
}
fmt.Fprintf(&progressText, "- %s\n", elem.FileName())
}
}
if err := styling.Perform(&entityBuilder,
styling.Plain(progressText.String()),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(info.TaskID()),
},
},
},
})
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
}
}
func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) {
log.FromContext(ctx).Debugf("Batch import task progress tracking done for message %d in chat %d", p.MessageID, p.ChatID)
entityBuilder := entity.Builder{}
var resultText strings.Builder
if err != nil {
resultText.WriteString("❌ 导入失败\n")
fmt.Fprintf(&resultText, "错误: %v\n", err)
} else {
resultText.WriteString("✅ 导入完成\n")
}
elapsed := time.Since(p.start)
resultText.WriteString(fmt.Sprintf("\n总文件数: %d\n", info.Count()))
resultText.WriteString(fmt.Sprintf("总大小: %.2f MB\n", float64(info.TotalSize())/(1024*1024)))
resultText.WriteString(fmt.Sprintf("已上传: %.2f MB\n", float64(info.Uploaded())/(1024*1024)))
resultText.WriteString(fmt.Sprintf("耗时: %s\n", formatDuration(elapsed)))
if elapsed.Seconds() > 0 {
avgSpeed := float64(info.Uploaded()) / elapsed.Seconds()
resultText.WriteString(fmt.Sprintf("平均速度: %s/s\n", dlutil.FormatSize(int64(avgSpeed))))
}
failedFiles := info.FailedFiles()
if len(failedFiles) > 0 {
fmt.Fprintf(&resultText, "\n失败文件数: %d\n", len(failedFiles))
for i, name := range failedFiles {
if i >= 5 {
fmt.Fprintf(&resultText, "...和其他 %d 个文件\n", len(failedFiles)-5)
break
}
fmt.Fprintf(&resultText, "- %s\n", name)
}
}
if err := styling.Perform(&entityBuilder,
styling.Plain(resultText.String()),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.MessageID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.ChatID, req)
}
}
func shouldUpdateProgress(total, current int64, lastPercent int) bool {
if total == 0 {
return false
}
currentPercent := int((current * 100) / total)
return currentPercent > lastPercent && currentPercent%5 == 0
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d -= h * time.Hour
m := d / time.Minute
d -= m * time.Minute
s := d / time.Second
if h > 0 {
return fmt.Sprintf("%dh%dm%ds", h, m, s)
}
if m > 0 {
return fmt.Sprintf("%dm%ds", m, s)
}
return fmt.Sprintf("%ds", s)
}

View File

@@ -0,0 +1,97 @@
package batchimport
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/storagetypes"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
var _ core.Executable = (*Task)(nil)
type TaskElement struct {
ID string
SourceStorage storage.Storage
SourcePath string
FileInfo storagetypes.FileInfo
TargetStorage storage.Storage
TargetChatID int64
}
type Task struct {
ID string
ctx context.Context
elems []TaskElement
Progress ProgressTracker
IgnoreErrors bool
uploaded atomic.Int64
totalSize int64
processing map[string]TaskElementInfo
processingMu sync.RWMutex
failed map[string]error
}
// Title implements core.Executable.
func (t *Task) Title() string {
return fmt.Sprintf("[%s](%d files/%.2fMB)", t.Type(), len(t.elems), float64(t.totalSize)/(1024*1024))
}
// Type implements core.Executable.
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeBatchimport
}
// TaskID implements core.Executable.
func (t *Task) TaskID() string {
return t.ID
}
func NewTaskElement(
sourceStorage storage.Storage,
fileInfo storagetypes.FileInfo,
targetStorage storage.Storage,
targetChatID int64,
) *TaskElement {
id := xid.New().String()
return &TaskElement{
ID: id,
SourceStorage: sourceStorage,
SourcePath: fileInfo.Path,
FileInfo: fileInfo,
TargetStorage: targetStorage,
TargetChatID: targetChatID,
}
}
func NewBatchImportTask(
id string,
ctx context.Context,
elems []TaskElement,
progress ProgressTracker,
ignoreErrors bool,
) *Task {
task := &Task{
ID: id,
ctx: ctx,
elems: elems,
Progress: progress,
uploaded: atomic.Int64{},
totalSize: func() int64 {
var total int64
for _, elem := range elems {
total += elem.FileInfo.Size
}
return total
}(),
processing: make(map[string]TaskElementInfo),
IgnoreErrors: ignoreErrors,
failed: make(map[string]error),
}
return task
}

View File

@@ -0,0 +1,73 @@
package batchimport
type TaskElementInfo interface {
FileName() string
FileSize() int64
GetSourcePath() string
SourceStorageName() string
}
func (e *TaskElement) FileName() string {
return e.FileInfo.Name
}
func (e *TaskElement) FileSize() int64 {
return e.FileInfo.Size
}
func (e *TaskElement) GetSourcePath() string {
return e.SourcePath
}
func (e *TaskElement) SourceStorageName() string {
return e.SourceStorage.Name()
}
type TaskInfo interface {
TaskID() string
TotalSize() int64
Uploaded() int64
Count() int
Processing() []TaskElementInfo
FailedFiles() []string
}
func (t *Task) TotalSize() int64 {
return t.totalSize
}
func (t *Task) Uploaded() int64 {
return t.uploaded.Load()
}
func (t *Task) Count() int {
return len(t.elems)
}
func (t *Task) Processing() []TaskElementInfo {
t.processingMu.RLock()
defer t.processingMu.RUnlock()
result := make([]TaskElementInfo, 0, len(t.processing))
for _, elem := range t.processing {
result = append(result, elem)
}
return result
}
func (t *Task) FailedFiles() []string {
t.processingMu.RLock()
defer t.processingMu.RUnlock()
result := make([]string, 0, len(t.failed))
for id := range t.failed {
// Find the element by ID
for _, elem := range t.elems {
if elem.ID == id {
result = append(result, elem.FileInfo.Name)
break
}
}
}
return result
}

View File

@@ -1,5 +1,5 @@
package tasktype
//go:generate go-enum --values --names --flag --nocase
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp)
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp,batchimport)
type TaskType string

View File

@@ -24,6 +24,8 @@ const (
TaskTypeAria2 TaskType = "aria2"
// TaskTypeYtdlp is a TaskType of type ytdlp.
TaskTypeYtdlp TaskType = "ytdlp"
// TaskTypeBatchimport is a TaskType of type batchimport.
TaskTypeBatchimport TaskType = "batchimport"
)
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
@@ -35,6 +37,7 @@ var _TaskTypeNames = []string{
string(TaskTypeDirectlinks),
string(TaskTypeAria2),
string(TaskTypeYtdlp),
string(TaskTypeBatchimport),
}
// TaskTypeNames returns a list of possible string values of TaskType.
@@ -53,6 +56,7 @@ func TaskTypeValues() []TaskType {
TaskTypeDirectlinks,
TaskTypeAria2,
TaskTypeYtdlp,
TaskTypeBatchimport,
}
}
@@ -75,6 +79,7 @@ var _TaskTypeValue = map[string]TaskType{
"directlinks": TaskTypeDirectlinks,
"aria2": TaskTypeAria2,
"ytdlp": TaskTypeYtdlp,
"batchimport": TaskTypeBatchimport,
}
// ParseTaskType attempts to convert a string to a TaskType.

View File

@@ -0,0 +1,12 @@
package storagetypes
import "time"
// FileInfo 表示文件元数据
type FileInfo struct {
Name string
Path string
Size int64
IsDir bool
ModTime time.Time
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
var UserStorages = make(map[int64][]Storage)
@@ -79,3 +80,14 @@ func LoadStorages(ctx context.Context) {
UserStorages[int64(user)] = GetUserStorages(ctx, int64(user))
}
}
// GetTelegramStorageByUserID returns the first enabled Telegram storage for the user
func GetTelegramStorageByUserID(ctx context.Context, chatID int64) (Storage, error) {
storages := GetUserStorages(ctx, chatID)
for _, stor := range storages {
if stor.Type() == storenum.Telegram {
return stor, nil
}
}
return nil, fmt.Errorf("no telegram storage found for user %d", chatID)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/duke-git/lancet/v2/fileutil"
config "github.com/krau/SaveAny-Bot/config/storage"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/krau/SaveAny-Bot/pkg/storagetypes"
)
type Local struct {
@@ -81,3 +82,51 @@ func (l *Local) Exists(ctx context.Context, storagePath string) bool {
}
return fileutil.IsExist(absPath)
}
// ListFiles implements StorageListable interface
func (l *Local) ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error) {
absPath := l.JoinStoragePath(dirPath)
entries, err := os.ReadDir(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read directory %s: %w", absPath, err)
}
files := make([]storagetypes.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
l.logger.Warnf("Failed to get file info for %s: %v", entry.Name(), err)
continue
}
filePath := filepath.Join(dirPath, entry.Name())
files = append(files, storagetypes.FileInfo{
Name: entry.Name(),
Path: filePath,
Size: info.Size(),
IsDir: entry.IsDir(),
ModTime: info.ModTime(),
})
}
return files, nil
}
// OpenFile implements StorageReadable interface
func (l *Local) OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) {
absPath := l.JoinStoragePath(filePath)
file, err := os.Open(absPath)
if err != nil {
return nil, 0, fmt.Errorf("failed to open file %s: %w", absPath, err)
}
stat, err := file.Stat()
if err != nil {
file.Close()
return nil, 0, fmt.Errorf("failed to stat file %s: %w", absPath, err)
}
return file, stat.Size(), nil
}

View File

@@ -7,6 +7,7 @@ import (
storcfg "github.com/krau/SaveAny-Bot/config/storage"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/krau/SaveAny-Bot/pkg/storagetypes"
"github.com/krau/SaveAny-Bot/storage/alist"
"github.com/krau/SaveAny-Bot/storage/local"
"github.com/krau/SaveAny-Bot/storage/minio"
@@ -30,6 +31,18 @@ type StorageCannotStream interface {
CannotStream() string
}
// StorageListable 表示支持列举目录内容的存储
type StorageListable interface {
Storage
ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error)
}
// StorageReadable 表示支持读取文件内容的存储
type StorageReadable interface {
Storage
OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error)
}
var Storages = make(map[string]Storage)
type StorageConstructor func() Storage

View File

@@ -99,12 +99,6 @@ func (w *splitWriter) finalize() error {
}
func CreateSplitZip(ctx context.Context, reader io.Reader, size int64, fileName, outputBase string, partSize int64) error {
// seek the reader if possible
if rs, ok := reader.(io.ReadSeeker); ok {
if _, err := rs.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek reader: %w", err)
}
}
outputDir := filepath.Dir(outputBase)
if err := os.MkdirAll(outputDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)

View File

@@ -92,9 +92,6 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er
return nil
}
rs, seekable := r.(io.ReadSeeker)
if !seekable || rs == nil {
return fmt.Errorf("reader must implement io.ReadSeeker")
}
splitSize := t.config.SplitSizeMB * 1024 * 1024
if splitSize <= 0 {
splitSize = DefaultSplitSize
@@ -123,88 +120,96 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er
}
chatID = cid
}
mtype, err := mimetype.DetectReader(rs)
if err != nil {
return fmt.Errorf("failed to detect mimetype: %w", err)
}
if filename == "" {
filename = xid.New().String() + mtype.Extension()
}
upler := uploader.NewUploader(tctx.Raw).
WithPartSize(tglimit.MaxUploadPartSize).
WithThreads(dlutil.BestThreads(size, config.C().Threads))
peer := tryGetInputPeer(tctx, chatID)
if peer == nil || peer.Zero() {
return fmt.Errorf("failed to get input peer for chat ID %d", chatID)
}
var mtype *mimetype.MIME
if seekable {
var err error
mtype, err = mimetype.DetectReader(r)
if err != nil {
return fmt.Errorf("failed to detect mimetype: %w", err)
}
if filename == "" {
filename = xid.New().String() + mtype.Extension()
}
if _, err := rs.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek reader: %w", err)
if _, err := rs.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek reader: %w", err)
}
}
upler := uploader.NewUploader(tctx.Raw).
WithPartSize(tglimit.MaxUploadPartSize).
WithThreads(dlutil.BestThreads(size, config.C().Threads))
if size > splitSize {
// large file, use split uploader
return t.splitUpload(tctx, rs, filename, upler, peer, size, splitSize)
return t.splitUpload(tctx, r, filename, upler, peer, size, splitSize)
}
var file tg.InputFileClass
if size < 0 {
file, err = upler.FromReader(ctx, filename, rs)
var err error
if size <= 0 {
file, err = upler.FromReader(ctx, filename, r)
} else {
file, err = upler.Upload(ctx, uploader.NewUpload(filename, rs, size))
file, err = upler.Upload(ctx, uploader.NewUpload(filename, r, size))
}
if err != nil {
return fmt.Errorf("failed to upload file to telegram: %w", err)
}
caption := styling.Plain(filename)
forceFile := t.config.ForceFile
if strings.HasPrefix(mtype.String(), "image/") && size >= tglimit.MaxPhotoSize {
if mtype != nil && strings.HasPrefix(mtype.String(), "image/") && size >= tglimit.MaxPhotoSize {
forceFile = true
}
doc := message.UploadedDocument(file, caption).
Filename(filename).
ForceFile(forceFile).
MIME(mtype.String())
ForceFile(forceFile)
if mtype != nil {
doc = doc.MIME(mtype.String())
}
var media message.MediaOption = doc
switch mtypeStr := mtype.String(); {
case strings.HasPrefix(mtypeStr, "video/"):
media = doc.Video().SupportsStreaming()
thumb, err := extractThumbFrame(rs)
if err == nil {
thumb, err := upler.FromBytes(ctx, "thumb.jpg", thumb)
if mtype != nil && rs != nil {
switch mtypeStr := mtype.String(); {
case strings.HasPrefix(mtypeStr, "video/"):
media = doc.Video().SupportsStreaming()
thumb, err := extractThumbFrame(rs)
if err == nil {
doc = doc.Thumb(thumb)
thumb, err := upler.FromBytes(ctx, "thumb.jpg", thumb)
if err == nil {
doc = doc.Thumb(thumb)
}
}
rs.Seek(0, io.SeekStart)
switch mtypeStr {
case "video/mp4":
info, err := getMP4Meta(rs)
if err != nil {
// Fallback to ffprobe if gomedia fails (e.g., malformed MP4)
rs.Seek(0, io.SeekStart)
info, err = getVideoMetadata(rs)
}
if err == nil {
media = doc.Video().
Duration(time.Duration(info.Duration)*time.Second).
Resolution(info.Width, info.Height).
SupportsStreaming()
}
default:
info, err := getVideoMetadata(rs)
if err == nil {
media = doc.Video().
Duration(time.Duration(info.Duration)*time.Second).
Resolution(info.Width, info.Height).
SupportsStreaming()
}
}
case strings.HasPrefix(mtypeStr, "audio/"):
media = doc.Audio().Title(filename)
case strings.HasPrefix(mtypeStr, "image/") && !strings.HasSuffix(mtypeStr, "webp"):
media = message.UploadedPhoto(file, caption)
}
rs.Seek(0, io.SeekStart)
switch mtypeStr {
case "video/mp4":
info, err := getMP4Meta(rs)
if err != nil {
// Fallback to ffprobe if gomedia fails (e.g., malformed MP4)
rs.Seek(0, io.SeekStart)
info, err = getVideoMetadata(rs)
}
if err == nil {
media = doc.Video().
Duration(time.Duration(info.Duration)*time.Second).
Resolution(info.Width, info.Height).
SupportsStreaming()
}
default:
info, err := getVideoMetadata(rs)
if err == nil {
media = doc.Video().
Duration(time.Duration(info.Duration)*time.Second).
Resolution(info.Width, info.Height).
SupportsStreaming()
}
}
case strings.HasPrefix(mtypeStr, "audio/"):
media = doc.Audio().Title(filename)
case strings.HasPrefix(mtypeStr, "image/") && !strings.HasSuffix(mtypeStr, "webp"):
media = message.UploadedPhoto(file, caption)
}
sender := tctx.Sender
_, err = sender.WithUploader(upler).To(peer).Media(ctx, media)
@@ -215,7 +220,7 @@ func (t *Telegram) CannotStream() string {
return "Telegram storage must use a ReaderSeeker"
}
func (t *Telegram) splitUpload(ctx *ext.Context, rs io.ReadSeeker, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error {
func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error {
tempId := xid.New().String()
outputBase := filepath.Join(config.C().Temp.BasePath, tempId, strings.Split(filename, ".")[0])
defer func() {
@@ -224,7 +229,7 @@ func (t *Telegram) splitUpload(ctx *ext.Context, rs io.ReadSeeker, filename stri
log.FromContext(ctx).Warnf("Failed to cleanup temp split files: %s", err)
}
}()
if err := CreateSplitZip(ctx, rs, fileSize, filename, outputBase, splitSize); err != nil {
if err := CreateSplitZip(ctx, r, fileSize, filename, outputBase, splitSize); err != nil {
return fmt.Errorf("failed to create split zip: %w", err)
}
matched, err := filepath.Glob(outputBase + ".z*")