diff --git a/client/bot/handlers/import.go b/client/bot/handlers/import.go new file mode 100644 index 0000000..4271e89 --- /dev/null +++ b/client/bot/handlers/import.go @@ -0,0 +1,182 @@ +package handlers + +import ( + "fmt" + "regexp" + + "github.com/celestix/gotgproto/dispatcher" + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" + "github.com/krau/SaveAny-Bot/common/utils/strutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/config" + storconfig "github.com/krau/SaveAny-Bot/config/storage" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/tasks/batchimport" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +func handleImportCmd(ctx *ext.Context, update *ext.Update) error { + logger := log.FromContext(ctx) + args := strutil.ParseArgsRespectQuotes(update.EffectiveMessage.Text) + + if len(args) < 3 { + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportUsage, nil)), nil) + return dispatcher.EndGroups + } + + storageName := args[1] + dirPath := args[2] + + userID := update.GetUserChat().GetID() + + stor, err := storage.GetStorageByUserIDAndName(ctx, userID, storageName) + if err != nil { + logger.Errorf("Failed to get storage by user ID and name: %s", err) + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportErrorStorageNotFound, map[string]any{ + "StorageName": storageName, + "Error": err, + })), nil) + return dispatcher.EndGroups + } + + listable, ok := stor.(storage.StorageListable) + if !ok { + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportErrorStorageNotListable, map[string]any{ + "StorageName": storageName, + })), nil) + return dispatcher.EndGroups + } + + _, ok = stor.(storage.StorageReadable) + if !ok { + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportErrorStorageNotReadable, map[string]any{ + "StorageName": storageName, + })), nil) + return dispatcher.EndGroups + } + + telegramStorage, err := storage.GetTelegramStorageByUserID(ctx, userID) + if err != nil { + ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportErrorNoTelegramStorage, map[string]any{ + "Error": err, + })), nil) + return dispatcher.EndGroups + } + + replied, err := ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgImportInfoFetchingFiles, nil)), nil) + if err != nil { + logger.Errorf("Failed to reply: %s", err) + return dispatcher.EndGroups + } + + files, err := listable.ListFiles(ctx, dirPath) + if err != nil { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorListFilesFailed, map[string]any{"Error": err}), + }) + return dispatcher.EndGroups + } + + var filter *regexp.Regexp + if len(args) >= 5 { + filter, err = regexp.Compile(args[4]) + if err != nil { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorInvalidRegex, map[string]any{"Error": err}), + }) + return dispatcher.EndGroups + } + } + + filteredFiles := make([]storagetypes.FileInfo, 0) + for _, file := range files { + if file.IsDir { + continue + } + if filter != nil && !filter.MatchString(file.Name) { + continue + } + filteredFiles = append(filteredFiles, file) + } + + if len(filteredFiles) == 0 { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorNoFilesToImport, nil), + }) + return dispatcher.EndGroups + } + + // Get default chat_id from Telegram storage config + targetChatID := int64(0) + if telegramCfg := config.C().GetStorageByName(telegramStorage.Name()); telegramCfg != nil { + if tgCfg, ok := telegramCfg.(*storconfig.TelegramStorageConfig); ok { + targetChatID = tgCfg.ChatID + } + } + + if len(args) >= 4 { + parsedChatID, err := tgutil.ParseChatID(ctx, args[3]) + if err != nil { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorInvalidChatId, map[string]any{"Error": err}), + }) + return dispatcher.EndGroups + } + targetChatID = parsedChatID + } + + if targetChatID == 0 { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorNoTargetChatId, nil), + }) + return dispatcher.EndGroups + } + + elems := make([]batchimport.TaskElement, 0, len(filteredFiles)) + var totalSize int64 + for _, file := range filteredFiles { + elem := batchimport.NewTaskElement(stor, file, telegramStorage, targetChatID) + elems = append(elems, *elem) + totalSize += file.Size + } + + taskID := xid.New().String() + injectCtx := tgutil.ExtWithContext(ctx.Context, ctx) + task := batchimport.NewBatchImportTask( + taskID, + injectCtx, + elems, + batchimport.NewProgressTracker(replied.ID, userID), + true, // IgnoreErrors + ) + + if err := core.AddTask(injectCtx, task); err != nil { + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportErrorAddTaskFailed, map[string]any{"Error": err}), + }) + return dispatcher.EndGroups + } + + ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{ + ID: replied.ID, + Message: i18n.T(i18nk.BotMsgImportInfoTaskAdded, map[string]any{ + "Count": len(elems), + "SizeMB": fmt.Sprintf("%.2f", float64(totalSize)/(1024*1024)), + "TaskID": taskID, + }), + }) + + return dispatcher.EndGroups +} diff --git a/client/bot/handlers/register.go b/client/bot/handlers/register.go index cec750a..7375702 100644 --- a/client/bot/handlers/register.go +++ b/client/bot/handlers/register.go @@ -31,6 +31,7 @@ var CommandHandlers = []DescCommandHandler{ {"dl", i18nk.BotMsgCmdDl, handleDlCmd}, {"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd}, {"ytdlp", i18nk.BotMsgCmdYtdlp, handleYtdlpCmd}, + {"import", i18nk.BotMsgCmdImport, handleImportCmd}, {"task", i18nk.BotMsgCmdTask, handleTaskCmd}, {"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd}, {"config", i18nk.BotMsgCmdConfig, handleConfigCmd}, diff --git a/common/i18n/i18nk/keys.go b/common/i18n/i18nk/keys.go index d96b4cb..29af390 100644 --- a/common/i18n/i18nk/keys.go +++ b/common/i18n/i18nk/keys.go @@ -21,6 +21,7 @@ const ( BotMsgCmdDl Key = "bot.msg.cmd.dl" BotMsgCmdFnametmpl Key = "bot.msg.cmd.fnametmpl" BotMsgCmdHelp Key = "bot.msg.cmd.help" + BotMsgCmdImport Key = "bot.msg.cmd.import" BotMsgCmdLswatch Key = "bot.msg.cmd.lswatch" BotMsgCmdParser Key = "bot.msg.cmd.parser" BotMsgCmdRule Key = "bot.msg.cmd.rule" @@ -105,6 +106,20 @@ const ( BotMsgDlInfoFilesSelectStorage Key = "bot.msg.dl.info_files_select_storage" BotMsgDlUsage Key = "bot.msg.dl.usage" BotMsgHelpTextFmt Key = "bot.msg.help_text_fmt" + BotMsgImportErrorAddTaskFailed Key = "bot.msg.import.error_add_task_failed" + BotMsgImportErrorInvalidChatId Key = "bot.msg.import.error_invalid_chat_id" + BotMsgImportErrorInvalidRegex Key = "bot.msg.import.error_invalid_regex" + BotMsgImportErrorListFilesFailed Key = "bot.msg.import.error_list_files_failed" + BotMsgImportErrorNoFilesToImport Key = "bot.msg.import.error_no_files_to_import" + BotMsgImportErrorNoTargetChatId Key = "bot.msg.import.error_no_target_chat_id" + BotMsgImportErrorNoTelegramStorage Key = "bot.msg.import.error_no_telegram_storage" + BotMsgImportErrorStorageNotFound Key = "bot.msg.import.error_storage_not_found" + BotMsgImportErrorStorageNotListable Key = "bot.msg.import.error_storage_not_listable" + BotMsgImportErrorStorageNotReadable Key = "bot.msg.import.error_storage_not_readable" + BotMsgImportInfoFetchingFiles Key = "bot.msg.import.info_fetching_files" + BotMsgImportInfoTaskAdded Key = "bot.msg.import.info_task_added" + BotMsgImportStartStats Key = "bot.msg.import.start_stats" + BotMsgImportUsage Key = "bot.msg.import.usage" BotMsgMediaGroupErrorBuildStorageSelectKeyboardFailed Key = "bot.msg.media_group.error_build_storage_select_keyboard_failed" BotMsgMediaGroupInfoGroupFoundFilesSelectStorage Key = "bot.msg.media_group.info_group_found_files_select_storage" BotMsgMediaGroupInfoSavingFiles Key = "bot.msg.media_group.info_saving_files" @@ -149,6 +164,20 @@ const ( BotMsgProgressFileProcessingPrefix Key = "bot.msg.progress.file_processing_prefix" BotMsgProgressFileSizePrefix Key = "bot.msg.progress.file_size_prefix" BotMsgProgressFileStartPrefix Key = "bot.msg.progress.file_start_prefix" + BotMsgProgressImportAvgSpeedPrefix Key = "bot.msg.progress.import_avg_speed_prefix" + BotMsgProgressImportElapsedTimePrefix Key = "bot.msg.progress.import_elapsed_time_prefix" + BotMsgProgressImportFailedFilesPrefix Key = "bot.msg.progress.import_failed_files_prefix" + BotMsgProgressImportFailedPrefix Key = "bot.msg.progress.import_failed_prefix" + BotMsgProgressImportProcessingMore Key = "bot.msg.progress.import_processing_more" + BotMsgProgressImportProcessingPrefix Key = "bot.msg.progress.import_processing_prefix" + BotMsgProgressImportProgressPrefix Key = "bot.msg.progress.import_progress_prefix" + BotMsgProgressImportRemainingTimePrefix Key = "bot.msg.progress.import_remaining_time_prefix" + BotMsgProgressImportSpeedPrefix Key = "bot.msg.progress.import_speed_prefix" + BotMsgProgressImportStartPrefix Key = "bot.msg.progress.import_start_prefix" + BotMsgProgressImportSuccessPrefix Key = "bot.msg.progress.import_success_prefix" + BotMsgProgressImportTotalFilesPrefix Key = "bot.msg.progress.import_total_files_prefix" + BotMsgProgressImportTotalSizePrefix Key = "bot.msg.progress.import_total_size_prefix" + BotMsgProgressImportUploadedPrefix Key = "bot.msg.progress.import_uploaded_prefix" BotMsgProgressParsedDonePrefix Key = "bot.msg.progress.parsed_done_prefix" BotMsgProgressParsedStartPrefix Key = "bot.msg.progress.parsed_start_prefix" BotMsgProgressProcessingListPrefix Key = "bot.msg.progress.processing_list_prefix" diff --git a/common/i18n/locale/en.yaml b/common/i18n/locale/en.yaml index 995e3e9..2cb38eb 100644 --- a/common/i18n/locale/en.yaml +++ b/common/i18n/locale/en.yaml @@ -29,6 +29,7 @@ bot: /silent - Toggle silent mode /storage - Set default storage /save [custom filename] - Save file + /import [channel_id] [filter] - Import files from storage to Telegram /dir - Manage storage directories /rule - Manage rules /config - Modify configuration @@ -52,6 +53,7 @@ bot: dl: "Download files from given links" aria2dl: "Download files using Aria2" ytdlp: "Download video/audio using yt-dlp" + import: "Import files from storage to Telegram" task: "Manage task queue" cancel: "Cancel task" watch: "Watch chats (UserBot)" @@ -294,6 +296,20 @@ bot: info_urls_select_storage: "Found {{.Count}} links, please select storage" info_downloading: "Downloading via yt-dlp..." error_download_failed: "yt-dlp download failed: {{.Error}}" + import: + usage: "Usage: /import [target_chat_id] [filter]\n\nExamples:\n/import local1 /downloads\n/import MyAlist /media/photos -1001234567890\n/import MyLocal /backup \".*[.]mp4$\"" + error_storage_not_found: "Storage '{{.StorageName}}' not found or access denied: {{.Error}}" + error_storage_not_listable: "Storage '{{.StorageName}}' does not support listing files" + error_storage_not_readable: "Storage '{{.StorageName}}' does not support reading files" + error_no_telegram_storage: "No Telegram storage found: {{.Error}}" + info_fetching_files: "Fetching file list..." + error_list_files_failed: "Failed to list files: {{.Error}}" + error_invalid_regex: "Invalid regular expression: {{.Error}}" + error_no_files_to_import: "No files to import in directory" + error_invalid_chat_id: "Invalid Chat ID: {{.Error}}" + error_no_target_chat_id: "No target channel ID specified and Telegram storage has no default chat_id configured" + error_add_task_failed: "Failed to add task: {{.Error}}" + info_task_added: "Added {{.Count}} files to import queue\nTotal size: {{.SizeMB}} MB\nTask ID: {{.TaskID}}" cancel: usage: "Usage: /cancel " error_cancel_failed: "Failed to cancel task: {{.Error}}" @@ -342,6 +358,20 @@ bot: ytdlp_done: "yt-dlp download completed and transferred ({{.Count}} files)\n" downloaded_prefix: "\nDownloaded: " current_speed_prefix: "\nCurrent speed: " + import_start_prefix: "Importing: " + import_progress_prefix: "Import progress: " + import_uploaded_prefix: "\nUploaded: " + import_speed_prefix: "\nSpeed: " + import_remaining_time_prefix: "\nRemaining time: " + import_processing_prefix: "\nProcessing:\n" + import_processing_more: "...and {{.Count}} more files\n" + import_failed_prefix: "Import failed\n" + import_success_prefix: "Import completed\n" + import_total_files_prefix: "\nTotal files: " + import_total_size_prefix: "\nTotal size: " + import_elapsed_time_prefix: "\nElapsed time: " + import_avg_speed_prefix: "\nAverage speed: " + import_failed_files_prefix: "\nFailed files: " syncpeers: start: "Starting to sync peers..." done: "Peer sync completed, total {{.Count}} chats synced" diff --git a/common/i18n/locale/zh-Hans.yaml b/common/i18n/locale/zh-Hans.yaml index c4218d6..df073b3 100644 --- a/common/i18n/locale/zh-Hans.yaml +++ b/common/i18n/locale/zh-Hans.yaml @@ -30,6 +30,7 @@ bot: /storage - 设置默认存储位置 /save [自定义文件名] - 保存文件 /dl <链接1> <链接2> ... - 下载给定链接的文件 + /import <存储名> <目录路径> [频道ID] [过滤器] - 从存储端导入文件到 Telegram /dir - 管理存储目录 /rule - 管理规则 /config - 修改配置 @@ -53,6 +54,7 @@ bot: dl: "下载给定链接的文件" aria2dl: "使用 Aria2 下载给定链接的文件" ytdlp: "使用 yt-dlp 下载视频/音频" + import: "从存储端导入文件到 Telegram" task: "管理任务队列" cancel: "取消任务" watch: "监听聊天(UserBot)" @@ -295,6 +297,26 @@ bot: info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置" info_downloading: "正在通过 yt-dlp 下载..." error_download_failed: "yt-dlp 下载失败: {{.Error}}" + import: + usage: | + 用法: /import [target_chat_id] [filter] + 示例: + /import 本机1 /downloads + /import MyAlist /media/photos -1001234567890 + /import MyLocal /backup ".*\.mp4$" + error_storage_not_found: "存储端 '{{.StorageName}}' 不存在或您无权访问: {{.Error}}" + error_storage_not_listable: "存储端 '{{.StorageName}}' 不支持列举文件功能" + error_storage_not_readable: "存储端 '{{.StorageName}}' 不支持读取文件功能" + error_no_telegram_storage: "未找到可用的 Telegram 存储: {{.Error}}" + info_fetching_files: "正在获取文件列表..." + error_list_files_failed: "获取文件列表失败: {{.Error}}" + error_invalid_regex: "正则表达式无效: {{.Error}}" + error_no_files_to_import: "目录中没有可导入的文件" + error_invalid_chat_id: "无效的 Chat ID: {{.Error}}" + error_no_target_chat_id: "未指定目标频道 ID,且 Telegram 存储未配置默认 chat_id" + error_add_task_failed: "添加任务失败: {{.Error}}" + info_task_added: "已添加 {{.Count}} 个文件到导入队列\n总大小: {{.SizeMB}} MB\n任务 ID: {{.TaskID}}" + start_stats: "总文件数: {{.Count}}\n总大小: {{.SizeMB}} MB" cancel: usage: "用法: /cancel " error_cancel_failed: "取消任务失败: {{.Error}}" @@ -343,6 +365,20 @@ bot: ytdlp_done: "yt-dlp 下载完成并已转存 ({{.Count}} 个文件)\n" downloaded_prefix: "\n已下载: " current_speed_prefix: "\n当前速度: " + import_start_prefix: "正在导入: " + import_progress_prefix: "导入进度: " + import_uploaded_prefix: "\n已上传: " + import_speed_prefix: "\n速度: " + import_remaining_time_prefix: "\n剩余时间: " + import_processing_prefix: "\n正在处理:\n" + import_processing_more: "...和其他 {{.Count}} 个文件\n" + import_failed_prefix: "导入失败\n" + import_success_prefix: "导入完成\n" + import_total_files_prefix: "\n总文件数: " + import_total_size_prefix: "\n总大小: " + import_elapsed_time_prefix: "\n耗时: " + import_avg_speed_prefix: "\n平均速度: " + import_failed_files_prefix: "\n失败文件数: " syncpeers: start: "正在同步对话列表..." success: "对话列表同步完成, 共同步 {{.Count}} 个对话" @@ -353,4 +389,4 @@ bot: info_adding_aria2_download: "正在添加 Aria2 下载任务..." error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}" info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}" - info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列" \ No newline at end of file + info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列" diff --git a/common/utils/dlutil/dl.go b/common/utils/dlutil/dl.go index 9ad027f..6d48b5f 100644 --- a/common/utils/dlutil/dl.go +++ b/common/utils/dlutil/dl.go @@ -1,6 +1,9 @@ package dlutil -import "time" +import ( + "fmt" + "time" +) var threadsLevels = []struct { threads int @@ -31,3 +34,23 @@ func GetSpeed(downloaded int64, startTime time.Time) float64 { } return float64(downloaded) / elapsed } + +// FormatSize formats a byte size as a human-readable string +func FormatSize(bytes int64) string { + const ( + KB = 1024 + MB = KB * 1024 + GB = MB * 1024 + ) + + switch { + case bytes >= GB: + return fmt.Sprintf("%.2f GB", float64(bytes)/float64(GB)) + case bytes >= MB: + return fmt.Sprintf("%.2f MB", float64(bytes)/float64(MB)) + case bytes >= KB: + return fmt.Sprintf("%.2f KB", float64(bytes)/float64(KB)) + default: + return fmt.Sprintf("%d B", bytes) + } +} diff --git a/core/tasks/batchimport/execute.go b/core/tasks/batchimport/execute.go new file mode 100644 index 0000000..b76d93f --- /dev/null +++ b/core/tasks/batchimport/execute.go @@ -0,0 +1,141 @@ +package batchimport + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" + "github.com/krau/SaveAny-Bot/storage" + "golang.org/x/sync/errgroup" +) + +// Execute implements core.Executable. +func (t *Task) Execute(ctx context.Context) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_import[%s]", t.ID)) + logger.Info("Starting batch import task") + t.Progress.OnStart(ctx, t) + + workers := config.C().Workers + eg, gctx := errgroup.WithContext(ctx) + eg.SetLimit(workers) + + for _, elem := range t.elems { + eg.Go(func() error { + t.processingMu.RLock() + if t.processing[elem.ID] != nil { + t.processingMu.RUnlock() + return fmt.Errorf("element with ID %s is already being processed", elem.ID) + } + t.processingMu.RUnlock() + + t.processingMu.Lock() + t.processing[elem.ID] = &elem + t.processingMu.Unlock() + + defer func() { + t.processingMu.Lock() + delete(t.processing, elem.ID) + t.processingMu.Unlock() + }() + + err := t.processElement(gctx, elem) + if err != nil && !t.IgnoreErrors { + return err + } + if err != nil { + t.processingMu.Lock() + t.failed[elem.ID] = err + t.processingMu.Unlock() + logger.Errorf("Failed to process file %s: %v", elem.FileInfo.Name, err) + } + return nil + }) + } + + err := eg.Wait() + if err != nil { + logger.Errorf("Error during batch import processing: %v", err) + } else { + logger.Info("Batch import task completed successfully") + } + + t.Progress.OnDone(ctx, t, err) + return err +} + +func (t *Task) processElement(ctx context.Context, elem TaskElement) error { + logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.FileInfo.Name)) + + // Check whether the source storage supports reading + readableStorage, ok := elem.SourceStorage.(storage.StorageReadable) + if !ok { + return fmt.Errorf("source storage %s does not support reading", elem.SourceStorage.Name()) + } + + logger.Info("Opening file from source storage") + reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer reader.Close() + + // Build Telegram storage path: // + storagePath := fmt.Sprintf("/%d/%s", elem.TargetChatID, elem.FileInfo.Name) + + // 注入文件大小到 context + ctx = context.WithValue(ctx, ctxkey.ContentLength, size) + + if config.C().Stream { + if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil { + return fmt.Errorf("failed to upload file to telegram: %w", err) + } + } else { + logger.Info("Downloading to temporary file for ReadSeeker support") + tempFile, err := t.downloadToTemp(reader, elem.FileInfo.Name) + if err != nil { + return fmt.Errorf("failed to download to temp: %w", err) + } + defer os.Remove(tempFile.Name()) + defer tempFile.Close() + + if _, err := tempFile.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek temp file: %w", err) + } + + logger.Infof("Uploading file to Telegram storage (size: %d bytes)", size) + if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil { + return fmt.Errorf("failed to upload file to telegram: %w", err) + } + } + + t.uploaded.Add(size) + t.Progress.OnProgress(ctx, t) + + logger.Info("File uploaded successfully") + return nil +} + +func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, error) { + tempDir := config.C().Temp.BasePath + if tempDir == "" { + tempDir = os.TempDir() + } + + tempFile, err := os.CreateTemp(tempDir, filepath.Base(filename)+"-*.tmp") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + + if _, err := io.Copy(tempFile, reader); err != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + return nil, fmt.Errorf("failed to copy to temp file: %w", err) + } + + return tempFile, nil +} diff --git a/core/tasks/batchimport/progress.go b/core/tasks/batchimport/progress.go new file mode 100644 index 0000000..fe89b02 --- /dev/null +++ b/core/tasks/batchimport/progress.go @@ -0,0 +1,244 @@ +package batchimport + +import ( + "context" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/charmbracelet/log" + "github.com/gotd/td/telegram/message/entity" + "github.com/gotd/td/telegram/message/styling" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/i18n/i18nk" + "github.com/krau/SaveAny-Bot/common/utils/dlutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type ProgressTracker interface { + OnStart(ctx context.Context, info TaskInfo) + OnProgress(ctx context.Context, info TaskInfo) + OnDone(ctx context.Context, info TaskInfo, err error) +} + +type Progress struct { + MessageID int + ChatID int64 + start time.Time + lastUpdatePercent atomic.Int32 +} + +func NewProgressTracker(messageID int, chatID int64) ProgressTracker { + return &Progress{ + MessageID: messageID, + ChatID: chatID, + } +} + +func (p *Progress) OnStart(ctx context.Context, info TaskInfo) { + p.start = time.Now() + p.lastUpdatePercent.Store(0) + log.FromContext(ctx).Debugf("Batch import task progress tracking started for message %d in chat %d", p.MessageID, p.ChatID) + + sizeMB := float64(info.TotalSize()) / (1024 * 1024) + statsText := i18n.T(i18nk.BotMsgImportStartStats, map[string]any{ + "SizeMB": fmt.Sprintf("%.2f", sizeMB), + "Count": info.Count(), + }) + + entityBuilder := entity.Builder{} + if err := styling.Perform(&entityBuilder, + styling.Plain(i18n.T(i18nk.BotMsgProgressImportStartPrefix, nil)), + styling.Code(statsText), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }, + }) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +func (p *Progress) OnProgress(ctx context.Context, info TaskInfo) { + if !shouldUpdateProgress(info.TotalSize(), info.Uploaded(), int(p.lastUpdatePercent.Load())) { + return + } + percent := int((info.Uploaded() * 100) / info.TotalSize()) + if p.lastUpdatePercent.Load() == int32(percent) { + return + } + p.lastUpdatePercent.Store(int32(percent)) + + log.FromContext(ctx).Debugf("Progress update: %s, %d/%d", info.TaskID(), info.Uploaded(), info.TotalSize()) + + entityBuilder := entity.Builder{} + var progressText strings.Builder + + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportProgressPrefix, nil)) + progressText.WriteString(fmt.Sprintf("%d%%", percent)) + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportUploadedPrefix, nil)) + progressText.WriteString(fmt.Sprintf("%.2f MB / %.2f MB", + float64(info.Uploaded())/(1024*1024), + float64(info.TotalSize())/(1024*1024))) + + if p.start.Unix() > 0 { + elapsed := time.Since(p.start) + speed := float64(info.Uploaded()) / elapsed.Seconds() + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportSpeedPrefix, nil)) + progressText.WriteString(dlutil.FormatSize(int64(speed)) + "/s") + + if info.Uploaded() > 0 { + remaining := time.Duration(float64(info.TotalSize()-info.Uploaded()) / speed * float64(time.Second)) + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportRemainingTimePrefix, nil)) + progressText.WriteString(formatDuration(remaining)) + } + } + + processing := info.Processing() + if len(processing) > 0 { + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportProcessingPrefix, nil)) + for i, elem := range processing { + if i >= 3 { + progressText.WriteString(i18n.T(i18nk.BotMsgProgressImportProcessingMore, map[string]any{"Count": len(processing) - 3})) + break + } + fmt.Fprintf(&progressText, "- %s\n", elem.FileName()) + } + } + + if err := styling.Perform(&entityBuilder, + styling.Plain(progressText.String()), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + req.SetReplyMarkup(&tg.ReplyInlineMarkup{ + Rows: []tg.KeyboardButtonRow{ + { + Buttons: []tg.KeyboardButtonClass{ + tgutil.BuildCancelButton(info.TaskID()), + }, + }, + }, + }) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +func (p *Progress) OnDone(ctx context.Context, info TaskInfo, err error) { + log.FromContext(ctx).Debugf("Batch import task progress tracking done for message %d in chat %d", p.MessageID, p.ChatID) + + entityBuilder := entity.Builder{} + var resultText strings.Builder + + if err != nil { + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportFailedPrefix, nil)) + resultText.WriteString(i18n.T(i18nk.BotMsgProgressErrorPrefix, nil)) + fmt.Fprintf(&resultText, "%v\n", err) + } else { + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportSuccessPrefix, nil)) + } + + elapsed := time.Since(p.start) + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportTotalFilesPrefix, nil)) + fmt.Fprintf(&resultText, "%d\n", info.Count()) + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportTotalSizePrefix, nil)) + fmt.Fprintf(&resultText, "%.2f MB\n", float64(info.TotalSize())/(1024*1024)) + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportUploadedPrefix, nil)) + fmt.Fprintf(&resultText, "%.2f MB\n", float64(info.Uploaded())/(1024*1024)) + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportElapsedTimePrefix, nil)) + fmt.Fprintf(&resultText, "%s\n", formatDuration(elapsed)) + + if elapsed.Seconds() > 0 { + avgSpeed := float64(info.Uploaded()) / elapsed.Seconds() + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportAvgSpeedPrefix, nil)) + fmt.Fprintf(&resultText, "%s/s\n", dlutil.FormatSize(int64(avgSpeed))) + } + + failedFiles := info.FailedFiles() + if len(failedFiles) > 0 { + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportFailedFilesPrefix, nil)) + fmt.Fprintf(&resultText, "%d\n", len(failedFiles)) + for i, name := range failedFiles { + if i >= 5 { + resultText.WriteString(i18n.T(i18nk.BotMsgProgressImportProcessingMore, map[string]any{"Count": len(failedFiles) - 5})) + break + } + fmt.Fprintf(&resultText, "- %s\n", name) + } + } + + if err := styling.Perform(&entityBuilder, + styling.Plain(resultText.String()), + ); err != nil { + log.FromContext(ctx).Errorf("Failed to build entities: %s", err) + return + } + + text, entities := entityBuilder.Complete() + req := &tg.MessagesEditMessageRequest{ + ID: p.MessageID, + } + req.SetMessage(text) + req.SetEntities(entities) + + ext := tgutil.ExtFromContext(ctx) + if ext != nil { + ext.EditMessage(p.ChatID, req) + } +} + +func shouldUpdateProgress(total, current int64, lastPercent int) bool { + if total == 0 { + return false + } + currentPercent := int((current * 100) / total) + return currentPercent > lastPercent && currentPercent%5 == 0 +} + +func formatDuration(d time.Duration) string { + d = d.Round(time.Second) + h := d / time.Hour + d -= h * time.Hour + m := d / time.Minute + d -= m * time.Minute + s := d / time.Second + + if h > 0 { + return fmt.Sprintf("%dh%dm%ds", h, m, s) + } + if m > 0 { + return fmt.Sprintf("%dm%ds", m, s) + } + return fmt.Sprintf("%ds", s) +} diff --git a/core/tasks/batchimport/task.go b/core/tasks/batchimport/task.go new file mode 100644 index 0000000..0536745 --- /dev/null +++ b/core/tasks/batchimport/task.go @@ -0,0 +1,97 @@ +package batchimport + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +var _ core.Executable = (*Task)(nil) + +type TaskElement struct { + ID string + SourceStorage storage.Storage + SourcePath string + FileInfo storagetypes.FileInfo + TargetStorage storage.Storage + TargetChatID int64 +} + +type Task struct { + ID string + ctx context.Context + elems []TaskElement + Progress ProgressTracker + IgnoreErrors bool + uploaded atomic.Int64 + totalSize int64 + processing map[string]TaskElementInfo + processingMu sync.RWMutex + failed map[string]error +} + +// Title implements core.Executable. +func (t *Task) Title() string { + return fmt.Sprintf("[%s](%d files/%.2fMB)", t.Type(), len(t.elems), float64(t.totalSize)/(1024*1024)) +} + +// Type implements core.Executable. +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeBatchimport +} + +// TaskID implements core.Executable. +func (t *Task) TaskID() string { + return t.ID +} + +func NewTaskElement( + sourceStorage storage.Storage, + fileInfo storagetypes.FileInfo, + targetStorage storage.Storage, + targetChatID int64, +) *TaskElement { + id := xid.New().String() + return &TaskElement{ + ID: id, + SourceStorage: sourceStorage, + SourcePath: fileInfo.Path, + FileInfo: fileInfo, + TargetStorage: targetStorage, + TargetChatID: targetChatID, + } +} + +func NewBatchImportTask( + id string, + ctx context.Context, + elems []TaskElement, + progress ProgressTracker, + ignoreErrors bool, +) *Task { + task := &Task{ + ID: id, + ctx: ctx, + elems: elems, + Progress: progress, + uploaded: atomic.Int64{}, + totalSize: func() int64 { + var total int64 + for _, elem := range elems { + total += elem.FileInfo.Size + } + return total + }(), + processing: make(map[string]TaskElementInfo), + IgnoreErrors: ignoreErrors, + failed: make(map[string]error), + } + return task +} diff --git a/core/tasks/batchimport/taskinfo.go b/core/tasks/batchimport/taskinfo.go new file mode 100644 index 0000000..a622c74 --- /dev/null +++ b/core/tasks/batchimport/taskinfo.go @@ -0,0 +1,73 @@ +package batchimport + +type TaskElementInfo interface { + FileName() string + FileSize() int64 + GetSourcePath() string + SourceStorageName() string +} + +func (e *TaskElement) FileName() string { + return e.FileInfo.Name +} + +func (e *TaskElement) FileSize() int64 { + return e.FileInfo.Size +} + +func (e *TaskElement) GetSourcePath() string { + return e.SourcePath +} + +func (e *TaskElement) SourceStorageName() string { + return e.SourceStorage.Name() +} + +type TaskInfo interface { + TaskID() string + TotalSize() int64 + Uploaded() int64 + Count() int + Processing() []TaskElementInfo + FailedFiles() []string +} + +func (t *Task) TotalSize() int64 { + return t.totalSize +} + +func (t *Task) Uploaded() int64 { + return t.uploaded.Load() +} + +func (t *Task) Count() int { + return len(t.elems) +} + +func (t *Task) Processing() []TaskElementInfo { + t.processingMu.RLock() + defer t.processingMu.RUnlock() + + result := make([]TaskElementInfo, 0, len(t.processing)) + for _, elem := range t.processing { + result = append(result, elem) + } + return result +} + +func (t *Task) FailedFiles() []string { + t.processingMu.RLock() + defer t.processingMu.RUnlock() + + result := make([]string, 0, len(t.failed)) + for id := range t.failed { + // Find the element by ID + for _, elem := range t.elems { + if elem.ID == id { + result = append(result, elem.FileInfo.Name) + break + } + } + } + return result +} diff --git a/pkg/enums/tasktype/tasktype.go b/pkg/enums/tasktype/tasktype.go index f1248c6..d8ba420 100644 --- a/pkg/enums/tasktype/tasktype.go +++ b/pkg/enums/tasktype/tasktype.go @@ -1,6 +1,5 @@ package tasktype -// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp) -// //go:generate go-enum --values --names --flag --nocase +// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp,batchimport) type TaskType string diff --git a/pkg/enums/tasktype/tasktype_enum.go b/pkg/enums/tasktype/tasktype_enum.go index 83b34a0..5cd4e27 100644 --- a/pkg/enums/tasktype/tasktype_enum.go +++ b/pkg/enums/tasktype/tasktype_enum.go @@ -24,6 +24,8 @@ const ( TaskTypeAria2 TaskType = "aria2" // TaskTypeYtdlp is a TaskType of type ytdlp. TaskTypeYtdlp TaskType = "ytdlp" + // TaskTypeBatchimport is a TaskType of type batchimport. + TaskTypeBatchimport TaskType = "batchimport" ) var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", ")) @@ -35,6 +37,7 @@ var _TaskTypeNames = []string{ string(TaskTypeDirectlinks), string(TaskTypeAria2), string(TaskTypeYtdlp), + string(TaskTypeBatchimport), } // TaskTypeNames returns a list of possible string values of TaskType. @@ -53,6 +56,7 @@ func TaskTypeValues() []TaskType { TaskTypeDirectlinks, TaskTypeAria2, TaskTypeYtdlp, + TaskTypeBatchimport, } } @@ -75,6 +79,7 @@ var _TaskTypeValue = map[string]TaskType{ "directlinks": TaskTypeDirectlinks, "aria2": TaskTypeAria2, "ytdlp": TaskTypeYtdlp, + "batchimport": TaskTypeBatchimport, } // ParseTaskType attempts to convert a string to a TaskType. diff --git a/pkg/storagetypes/fileinfo.go b/pkg/storagetypes/fileinfo.go new file mode 100644 index 0000000..6af6ab4 --- /dev/null +++ b/pkg/storagetypes/fileinfo.go @@ -0,0 +1,12 @@ +package storagetypes + +import "time" + +// FileInfo represents file metadata +type FileInfo struct { + Name string + Path string + Size int64 + IsDir bool + ModTime time.Time +} diff --git a/storage/alist/alist.go b/storage/alist/alist.go index d0438d1..f9fa875 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -16,6 +16,7 @@ import ( config "github.com/krau/SaveAny-Bot/config/storage" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" ) type Alist struct { @@ -215,3 +216,156 @@ func (a *Alist) Exists(ctx context.Context, storagePath string) bool { func (a *Alist) CannotStream() string { return "Alist does not support chunked transfer encoding" } + +// ListFiles implements StorageListable interface +func (a *Alist) ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error) { + a.logger.Debugf("Listing files in directory: %s", dirPath) + + reqBody := fsListRequest{ + Path: dirPath, + Password: "", + Page: 1, + PerPage: 0, // 0 means all files + Refresh: false, + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.baseURL+"/api/fs/list", bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", a.token) + req.Header.Set("Content-Type", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to list files: %s", resp.Status) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var listResp fsListResponse + if err := json.Unmarshal(data, &listResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal list response: %w", err) + } + + if listResp.Code != http.StatusOK { + return nil, fmt.Errorf("failed to list files: %d, %s", listResp.Code, listResp.Message) + } + + files := make([]storagetypes.FileInfo, 0, len(listResp.Data.Content)) + for _, item := range listResp.Data.Content { + // Parse modified time; log failures but keep zero value on error. + var modTime time.Time + if item.Modified != "" { + parsedTime, err := time.Parse(time.RFC3339, item.Modified) + if err != nil { + a.logger.With( + "path", path.Join(dirPath, item.Name), + "modified_raw", item.Modified, + ).Warnf("failed to parse modified time for file") + } else { + modTime = parsedTime + } + } + + files = append(files, storagetypes.FileInfo{ + Name: item.Name, + Path: path.Join(dirPath, item.Name), + Size: item.Size, + IsDir: item.IsDir, + ModTime: modTime, + }) + } + + a.logger.Debugf("Found %d files in directory %s", len(files), dirPath) + return files, nil +} + +// OpenFile implements StorageReadable interface +func (a *Alist) OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) { + a.logger.Debugf("Opening file: %s", filePath) + + // First, get file info to get the raw_url + reqBody := map[string]any{ + "path": filePath, + "password": "", + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, 0, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.baseURL+"/api/fs/get", bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", a.token) + req.Header.Set("Content-Type", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("failed to get file info: %s", resp.Status) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, 0, fmt.Errorf("failed to read response body: %w", err) + } + + var getResp fsGetResponse + if err := json.Unmarshal(data, &getResp); err != nil { + return nil, 0, fmt.Errorf("failed to unmarshal get response: %w", err) + } + + if getResp.Code != http.StatusOK { + return nil, 0, fmt.Errorf("failed to get file info: %d, %s", getResp.Code, getResp.Message) + } + + if getResp.Data.IsDir { + return nil, 0, fmt.Errorf("path is a directory, not a file") + } + + // Download the file from raw_url + downloadURL := getResp.Data.RawURL + if downloadURL == "" { + // If no raw_url, construct download URL + downloadURL = a.baseURL + "/d" + filePath + } + + downloadReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) + if err != nil { + return nil, 0, fmt.Errorf("failed to create download request: %w", err) + } + + downloadResp, err := a.client.Do(downloadReq) + if err != nil { + return nil, 0, fmt.Errorf("failed to download file: %w", err) + } + + if downloadResp.StatusCode != http.StatusOK { + downloadResp.Body.Close() + return nil, 0, fmt.Errorf("failed to download file: %s", downloadResp.Status) + } + + a.logger.Debugf("Opened file %s, size: %d bytes", filePath, getResp.Data.Size) + return downloadResp.Body, getResp.Data.Size, nil +} diff --git a/storage/alist/types.go b/storage/alist/types.go index 59be4d5..c3d28e8 100644 --- a/storage/alist/types.go +++ b/storage/alist/types.go @@ -46,4 +46,46 @@ type putResponse struct { type fsGetResponse struct { Code int `json:"code"` Message string `json:"message"` + Data struct { + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified string `json:"modified"` + Created string `json:"created"` + Sign string `json:"sign"` + Thumb string `json:"thumb"` + Type int `json:"type"` + RawURL string `json:"raw_url"` + Provider string `json:"provider"` + } `json:"data"` +} + +type fsListRequest struct { + Path string `json:"path"` + Password string `json:"password"` + Page int `json:"page"` + PerPage int `json:"per_page"` + Refresh bool `json:"refresh"` +} + +type fsListResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Content []struct { + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified string `json:"modified"` + Created string `json:"created"` + Sign string `json:"sign"` + Thumb string `json:"thumb"` + Type int `json:"type"` + } `json:"content"` + Total int64 `json:"total"` + Readme string `json:"readme"` + Header string `json:"header"` + Write bool `json:"write"` + Provider string `json:"provider"` + } `json:"data"` } diff --git a/storage/load.go b/storage/load.go index 3a926e6..09bf4bf 100644 --- a/storage/load.go +++ b/storage/load.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" ) var UserStorages = make(map[int64][]Storage) @@ -79,3 +80,14 @@ func LoadStorages(ctx context.Context) { UserStorages[int64(user)] = GetUserStorages(ctx, int64(user)) } } + +// GetTelegramStorageByUserID returns the first enabled Telegram storage for the user +func GetTelegramStorageByUserID(ctx context.Context, chatID int64) (Storage, error) { + storages := GetUserStorages(ctx, chatID) + for _, stor := range storages { + if stor.Type() == storenum.Telegram { + return stor, nil + } + } + return nil, fmt.Errorf("no telegram storage found for user %d", chatID) +} diff --git a/storage/local/local.go b/storage/local/local.go index 34ed4be..432f8af 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -12,6 +12,7 @@ import ( "github.com/duke-git/lancet/v2/fileutil" config "github.com/krau/SaveAny-Bot/config/storage" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" ) type Local struct { @@ -81,3 +82,51 @@ func (l *Local) Exists(ctx context.Context, storagePath string) bool { } return fileutil.IsExist(absPath) } + +// ListFiles implements StorageListable interface +func (l *Local) ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error) { + absPath := l.JoinStoragePath(dirPath) + + entries, err := os.ReadDir(absPath) + if err != nil { + return nil, fmt.Errorf("failed to read directory %s: %w", absPath, err) + } + + files := make([]storagetypes.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + l.logger.Warnf("Failed to get file info for %s: %v", entry.Name(), err) + continue + } + + filePath := filepath.Join(dirPath, entry.Name()) + files = append(files, storagetypes.FileInfo{ + Name: entry.Name(), + Path: filePath, + Size: info.Size(), + IsDir: entry.IsDir(), + ModTime: info.ModTime(), + }) + } + + return files, nil +} + +// OpenFile implements StorageReadable interface +func (l *Local) OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) { + absPath := l.JoinStoragePath(filePath) + + file, err := os.Open(absPath) + if err != nil { + return nil, 0, fmt.Errorf("failed to open file %s: %w", absPath, err) + } + + stat, err := file.Stat() + if err != nil { + file.Close() + return nil, 0, fmt.Errorf("failed to stat file %s: %w", absPath, err) + } + + return file, stat.Size(), nil +} diff --git a/storage/storage.go b/storage/storage.go index 0928d53..7283b83 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -7,6 +7,7 @@ import ( storcfg "github.com/krau/SaveAny-Bot/config/storage" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" "github.com/krau/SaveAny-Bot/storage/alist" "github.com/krau/SaveAny-Bot/storage/local" "github.com/krau/SaveAny-Bot/storage/minio" @@ -30,6 +31,18 @@ type StorageCannotStream interface { CannotStream() string } +// StorageListable 表示支持列举目录内容的存储 +type StorageListable interface { + Storage + ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error) +} + +// StorageReadable 表示支持读取文件内容的存储 +type StorageReadable interface { + Storage + OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) +} + var Storages = make(map[string]Storage) type StorageConstructor func() Storage diff --git a/storage/telegram/split.go b/storage/telegram/split.go index 50091be..cdb8e34 100644 --- a/storage/telegram/split.go +++ b/storage/telegram/split.go @@ -99,12 +99,6 @@ func (w *splitWriter) finalize() error { } func CreateSplitZip(ctx context.Context, reader io.Reader, size int64, fileName, outputBase string, partSize int64) error { - // seek the reader if possible - if rs, ok := reader.(io.ReadSeeker); ok { - if _, err := rs.Seek(0, io.SeekStart); err != nil { - return fmt.Errorf("failed to seek reader: %w", err) - } - } outputDir := filepath.Dir(outputBase) if err := os.MkdirAll(outputDir, os.ModePerm); err != nil { return fmt.Errorf("failed to create output directory: %w", err) diff --git a/storage/telegram/telegram.go b/storage/telegram/telegram.go index c7f1da9..8bb4220 100644 --- a/storage/telegram/telegram.go +++ b/storage/telegram/telegram.go @@ -92,9 +92,6 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er return nil } rs, seekable := r.(io.ReadSeeker) - if !seekable || rs == nil { - return fmt.Errorf("reader must implement io.ReadSeeker") - } splitSize := t.config.SplitSizeMB * 1024 * 1024 if splitSize <= 0 { splitSize = DefaultSplitSize @@ -123,88 +120,96 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er } chatID = cid } - mtype, err := mimetype.DetectReader(rs) - if err != nil { - return fmt.Errorf("failed to detect mimetype: %w", err) - } - if filename == "" { - filename = xid.New().String() + mtype.Extension() - } + upler := uploader.NewUploader(tctx.Raw). + WithPartSize(tglimit.MaxUploadPartSize). + WithThreads(dlutil.BestThreads(size, config.C().Threads)) peer := tryGetInputPeer(tctx, chatID) if peer == nil || peer.Zero() { return fmt.Errorf("failed to get input peer for chat ID %d", chatID) } + var mtype *mimetype.MIME + if seekable { + var err error + mtype, err = mimetype.DetectReader(rs) + if err != nil { + return fmt.Errorf("failed to detect mimetype: %w", err) + } + if filename == "" { + filename = xid.New().String() + mtype.Extension() + } - if _, err := rs.Seek(0, io.SeekStart); err != nil { - return fmt.Errorf("failed to seek reader: %w", err) + if _, err := rs.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek reader: %w", err) + } } - upler := uploader.NewUploader(tctx.Raw). - WithPartSize(tglimit.MaxUploadPartSize). - WithThreads(dlutil.BestThreads(size, config.C().Threads)) if size > splitSize { // large file, use split uploader - return t.splitUpload(tctx, rs, filename, upler, peer, size, splitSize) + return t.splitUpload(tctx, r, filename, upler, peer, size, splitSize) } var file tg.InputFileClass - if size < 0 { - file, err = upler.FromReader(ctx, filename, rs) + var err error + if size <= 0 { + file, err = upler.FromReader(ctx, filename, r) } else { - file, err = upler.Upload(ctx, uploader.NewUpload(filename, rs, size)) + file, err = upler.Upload(ctx, uploader.NewUpload(filename, r, size)) } if err != nil { return fmt.Errorf("failed to upload file to telegram: %w", err) } caption := styling.Plain(filename) forceFile := t.config.ForceFile - if strings.HasPrefix(mtype.String(), "image/") && size >= tglimit.MaxPhotoSize { + + if mtype != nil && strings.HasPrefix(mtype.String(), "image/") && size >= tglimit.MaxPhotoSize { forceFile = true } doc := message.UploadedDocument(file, caption). Filename(filename). - ForceFile(forceFile). - MIME(mtype.String()) - + ForceFile(forceFile) + if mtype != nil { + doc = doc.MIME(mtype.String()) + } var media message.MediaOption = doc - - switch mtypeStr := mtype.String(); { - case strings.HasPrefix(mtypeStr, "video/"): - media = doc.Video().SupportsStreaming() - thumb, err := extractThumbFrame(rs) - if err == nil { - thumb, err := upler.FromBytes(ctx, "thumb.jpg", thumb) + if mtype != nil && rs != nil { + switch mtypeStr := mtype.String(); { + case strings.HasPrefix(mtypeStr, "video/"): + media = doc.Video().SupportsStreaming() + thumb, err := extractThumbFrame(rs) if err == nil { - doc = doc.Thumb(thumb) + thumb, err := upler.FromBytes(ctx, "thumb.jpg", thumb) + if err == nil { + doc = doc.Thumb(thumb) + } } + rs.Seek(0, io.SeekStart) + switch mtypeStr { + case "video/mp4": + info, err := getMP4Meta(rs) + if err != nil { + // Fallback to ffprobe if gomedia fails (e.g., malformed MP4) + rs.Seek(0, io.SeekStart) + info, err = getVideoMetadata(rs) + } + if err == nil { + media = doc.Video(). + Duration(time.Duration(info.Duration)*time.Second). + Resolution(info.Width, info.Height). + SupportsStreaming() + } + default: + info, err := getVideoMetadata(rs) + if err == nil { + media = doc.Video(). + Duration(time.Duration(info.Duration)*time.Second). + Resolution(info.Width, info.Height). + SupportsStreaming() + } + } + case strings.HasPrefix(mtypeStr, "audio/"): + media = doc.Audio().Title(filename) + case strings.HasPrefix(mtypeStr, "image/") && !strings.HasSuffix(mtypeStr, "webp"): + media = message.UploadedPhoto(file, caption) } - rs.Seek(0, io.SeekStart) - switch mtypeStr { - case "video/mp4": - info, err := getMP4Meta(rs) - if err != nil { - // Fallback to ffprobe if gomedia fails (e.g., malformed MP4) - rs.Seek(0, io.SeekStart) - info, err = getVideoMetadata(rs) - } - if err == nil { - media = doc.Video(). - Duration(time.Duration(info.Duration)*time.Second). - Resolution(info.Width, info.Height). - SupportsStreaming() - } - default: - info, err := getVideoMetadata(rs) - if err == nil { - media = doc.Video(). - Duration(time.Duration(info.Duration)*time.Second). - Resolution(info.Width, info.Height). - SupportsStreaming() - } - } - case strings.HasPrefix(mtypeStr, "audio/"): - media = doc.Audio().Title(filename) - case strings.HasPrefix(mtypeStr, "image/") && !strings.HasSuffix(mtypeStr, "webp"): - media = message.UploadedPhoto(file, caption) } sender := tctx.Sender _, err = sender.WithUploader(upler).To(peer).Media(ctx, media) @@ -215,7 +220,7 @@ func (t *Telegram) CannotStream() string { return "Telegram storage must use a ReaderSeeker" } -func (t *Telegram) splitUpload(ctx *ext.Context, rs io.ReadSeeker, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error { +func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error { tempId := xid.New().String() outputBase := filepath.Join(config.C().Temp.BasePath, tempId, strings.Split(filename, ".")[0]) defer func() { @@ -224,7 +229,7 @@ func (t *Telegram) splitUpload(ctx *ext.Context, rs io.ReadSeeker, filename stri log.FromContext(ctx).Warnf("Failed to cleanup temp split files: %s", err) } }() - if err := CreateSplitZip(ctx, rs, fileSize, filename, outputBase, splitSize); err != nil { + if err := CreateSplitZip(ctx, r, fileSize, filename, outputBase, splitSize); err != nil { return fmt.Errorf("failed to create split zip: %w", err) } matched, err := filepath.Glob(outputBase + ".z*") diff --git a/storage/webdav/client.go b/storage/webdav/client.go index e16abc1..976c7aa 100644 --- a/storage/webdav/client.go +++ b/storage/webdav/client.go @@ -2,6 +2,7 @@ package webdav import ( "context" + "encoding/xml" "fmt" "io" "net/http" @@ -25,8 +26,40 @@ const ( WebdavMethodMkcol WebdavMethod = "MKCOL" WebdavMethodPropfind WebdavMethod = "PROPFIND" WebdavMethodPut WebdavMethod = "PUT" + WebdavMethodGet WebdavMethod = "GET" ) +// WebDAV XML structures for PROPFIND response +type Multistatus struct { + XMLName xml.Name `xml:"multistatus"` + Responses []Response `xml:"response"` +} + +type Response struct { + Href string `xml:"href"` + Propstat Propstat `xml:"propstat"` +} + +type Propstat struct { + Prop Prop `xml:"prop"` + Status string `xml:"status"` +} + +type Prop struct { + ResourceType ResourceType `xml:"resourcetype"` + GetContentLength int64 `xml:"getcontentlength"` + GetLastModified string `xml:"getlastmodified"` + DisplayName string `xml:"displayname"` +} + +type ResourceType struct { + Collection *struct{} `xml:"collection"` +} + +func (rt ResourceType) IsCollection() bool { + return rt.Collection != nil +} + func NewClient(baseURL, username, password string, httpClient *http.Client) *Client { if !strings.HasSuffix(baseURL, "/") { baseURL += "/" @@ -131,5 +164,79 @@ func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Re return nil } return fmt.Errorf("PUT: %s", resp.Status) - +} + +// ListDir lists files and directories in the given path +func (c *Client) ListDir(ctx context.Context, dirPath string) ([]Response, error) { + dirPath = strings.Trim(dirPath, "/") + u, err := url.Parse(c.BaseURL) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, dirPath) + if !strings.HasSuffix(u.Path, "/") { + u.Path += "/" + } + + resp, err := c.doRequest(ctx, WebdavMethodPropfind, u.String(), nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMultiStatus { + return nil, fmt.Errorf("PROPFIND: %s", resp.Status) + } + + var multistatus Multistatus + if err := xml.NewDecoder(resp.Body).Decode(&multistatus); err != nil { + return nil, fmt.Errorf("failed to decode PROPFIND response: %w", err) + } + + // Filter out the directory itself from results + var results []Response + basePath := u.Path + for _, r := range multistatus.Responses { + decodedHref, err := url.PathUnescape(r.Href) + if err != nil { + decodedHref = r.Href + } + // Skip the directory itself + if strings.TrimSuffix(decodedHref, "/") == strings.TrimSuffix(basePath, "/") { + continue + } + results = append(results, r) + } + + return results, nil +} + +// ReadFile downloads a file and returns a ReadCloser +func (c *Client) ReadFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) { + filePath = strings.Trim(filePath, "/") + u, err := url.Parse(c.BaseURL) + if err != nil { + return nil, 0, err + } + u.Path = path.Join(u.Path, filePath) + + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, 0, err + } + if c.Username != "" && c.Password != "" { + req.SetBasicAuth(c.Username, c.Password) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, 0, err + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, 0, fmt.Errorf("GET: %s", resp.Status) + } + + return resp.Body, resp.ContentLength, nil } diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index a887de8..c16430e 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "net/url" "path" "strings" "time" @@ -12,6 +13,7 @@ import ( "github.com/charmbracelet/log" config "github.com/krau/SaveAny-Bot/config/storage" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/krau/SaveAny-Bot/pkg/storagetypes" "github.com/rs/xid" ) @@ -84,3 +86,77 @@ func (w *Webdav) Exists(ctx context.Context, storagePath string) bool { } return exists } + +// ListFiles implements storage.StorageListable +func (w *Webdav) ListFiles(ctx context.Context, dirPath string) ([]storagetypes.FileInfo, error) { + w.logger.Infof("Listing files in %s", dirPath) + + // Join with base path + fullPath := path.Join(w.config.BasePath, dirPath) + + responses, err := w.client.ListDir(ctx, fullPath) + if err != nil { + w.logger.Errorf("Failed to list directory %s: %v", fullPath, err) + return nil, fmt.Errorf("failed to list directory: %w", err) + } + + files := make([]storagetypes.FileInfo, 0, len(responses)) + for _, resp := range responses { + // Parse the href to get the file name + decodedHref, err := url.PathUnescape(resp.Href) + if err != nil { + w.logger.Warnf("Failed to unescape href %q: %v; using original value", resp.Href, err) + decodedHref = resp.Href + } + + // Extract filename from href + name := path.Base(strings.TrimSuffix(decodedHref, "/")) + if name == "" || name == "." { + continue + } + + // Parse modification time + var modTime time.Time + if resp.Propstat.Prop.GetLastModified != "" { + // Try RFC1123 format (standard for WebDAV) + parsedTime, err := time.Parse(time.RFC1123, resp.Propstat.Prop.GetLastModified) + if err != nil { + w.logger.Warnf("Failed to parse last modified time %q for %s: %v", resp.Propstat.Prop.GetLastModified, decodedHref, err) + } else { + modTime = parsedTime + } + } + + isDir := resp.Propstat.Prop.ResourceType.IsCollection() + + fileInfo := storagetypes.FileInfo{ + Name: name, + Path: strings.TrimPrefix(decodedHref, w.config.BasePath), + Size: resp.Propstat.Prop.GetContentLength, + IsDir: isDir, + ModTime: modTime, + } + + files = append(files, fileInfo) + } + + w.logger.Debugf("Found %d files/directories in %s", len(files), dirPath) + return files, nil +} + +// OpenFile implements storage.StorageReadable +func (w *Webdav) OpenFile(ctx context.Context, filePath string) (io.ReadCloser, int64, error) { + w.logger.Infof("Opening file %s", filePath) + + // Join with base path + fullPath := path.Join(w.config.BasePath, filePath) + + reader, size, err := w.client.ReadFile(ctx, fullPath) + if err != nil { + w.logger.Errorf("Failed to open file %s: %v", fullPath, err) + return nil, 0, fmt.Errorf("failed to open file: %w", err) + } + + w.logger.Debugf("Opened file %s (size: %d bytes)", filePath, size) + return reader, size, nil +}