Compare commits

...

8 Commits

29 changed files with 1919 additions and 105 deletions

View File

@@ -1,86 +0,0 @@
# SaveAny-Bot AI 协作说明
本项目是一个将 Telegram 文件/消息转存到多种存储端的 Bot主要用 Go 实现CLI 入口在 `cmd/`,核心逻辑分布在 `client/``core/``config/``database/``storage/` 等目录。下面是针对本仓库的专用约定,供 AI 编码助手参考。
## 总体架构与入口
- **CLI 入口**
- 二进制入口:`main.go` 调用 `cmd.Execute(ctx)`
- 根命令:`cmd/root.go` 使用 `cobra` 定义 `saveany-bot``Run` 实现在 `cmd/run.go`
- **应用启动流程(非常重要)**
- `cmd/run.go::Run` 中按顺序完成:读取配置 `config.Init` → 初始化缓存 `common/cache` → 初始化 i18n `common/i18n` → 初始化数据库 `database.Init` → 加载存储 `storage.LoadStorages` → 加载解析器插件 `parsers.LoadPlugins`可选Userbot 登录 → 启动 Telegram Bot `client/bot.Init` → 启动核心任务队列消费 `core.Run`
- 添加新的初始化步骤时,请遵循该顺序并放在 `initAll` 中,而不是分散在各处。
## 配置与约定
- **配置系统**
- 使用 `viper` 读取 `config.toml`,核心结构体定义在 `config/viper.go::Config`
- 默认值在 `config.Init` 中通过 `viper.SetDefault` 定义,如 `workers``retry``telegram.*``db.*` 等。
- `Config.C()` 返回的是全局配置副本(值类型),不要在返回值上修改字段;如果需要修改配置流程,应在 `config.Init` 内或通过 `viper` 进行。
- 存储配置通过 `config/storage/factory.go::LoadStorageConfigs` 加载并校验,新增存储类型需:
-`pkg/enums/storage` 中增加枚举。
-`config/storage/` 下新增具体 `StorageConfig` 实现并实现 `Validate`
-`storageFactories` 映射中注册工厂方法。
- **环境变量**
- 所有配置键会被 `SAVEANY_` 前缀的环境变量覆盖,`.` 会被 `_` 替换(`SAVEANY_TELEGRAM_APP_ID` 等)。
## Telegram 客户端与中间件
- **Bot 客户端**
- 入口在 `client/bot/bot.go::Init`,使用 `gotgproto.NewClient` 创建Session 使用 `database.GetDialect(config.C().DB.Session)`,错误处理通过 `ErrorHandler` 回调完成。
- Handlers 注册集中在 `client/bot/handlers` 目录,`handlers.Register` 负责统一挂载;新增命令/消息处理逻辑时优先在该目录按功能拆分文件,实现后在 `handlers.Register` 中注册。
- Bot 命令列表依赖 `handlers.CommandHandlers` 来自动注册到 Telegram新增命令时务必更新该切片以保持 `/help` 与 Bot 命令列表一致。
- **中间件**
- 通用中间件位于 `client/middleware/`,包含 floodwait、防崩溃、重试等`middleware.NewDefaultMiddlewares``client/bot/bot.go` 中统一挂载。
- 新增跨所有更新生效的行为(如日志、统计)时,应优先实现为中间件。
## 核心任务与队列
- **任务接口与队列**
- 核心接口:`core/core.go::Executable`,包含 `Type() TaskType``Title()``TaskID()``Execute(ctx)`
- 任务队列:`pkg/queue.TaskQueue[Executable]`,由 `core.Run` 使用;`Workers` 数量来自配置 `config.C().Workers`
- 任务类型与实现示例位于 `core/tasks/**`例如文件任务、Telegraph 任务等;新增任务类型应放在对应子目录并实现 `Executable` 接口,然后通过 `core.AddTask` 入队。
- **生命周期 Hook**
- `core.worker` 在执行任务前后会根据 `config.C().Hook.Exec` 调用外部命令(`TaskBeforeStart` / `TaskSuccess` / `TaskFail` / `TaskCancel`)。
- 修改任务执行流程时需保留这些 Hook 调用,以免破坏用户已有集成。
## 数据库与持久化
- **数据库初始化**
- `database.Init` 使用配置 `config.C().DB.Path` 创建并连接 SQLite使用 `GetDialect` 抽象驱动(见 `database/driver_*.go`)。
- Migration 通过 `db.AutoMigrate(&User{}, &Dir{}, &Rule{}, &WatchChat{})` 完成,模型定义在 `database/*.go` 中。
- **用户同步约定**
- `database.syncUsers` 会根据 `config.C().Users` 同步数据库用户表:在配置中新增/删除用户会自动在 DB 中创建/删除对应记录。
- 开发涉及用户表逻辑时,请考虑该同步行为,避免在其他地方直接创建/删除用户记录而与配置冲突。
## 存储后端
- **存储抽象**
- 抽象接口在 `config/storage/types.go``storage/` 顶层(以及子目录)中;`config/storage/*.go` 处理配置解析,`storage/*` 处理真正的上传/下载实现。
- 现有实现包括 `local``alist``s3/minio``webdav``telegram` 等,每个后端都有对应子目录和配置结构体。
- **新增存储实现的推荐路径**
-`config/storage/` 下添加配置结构体 + `Validate`
-`storage/` 下添加具体实现(例如 `storage/foo/`)。
-`pkg/enums/storage``storageFactories` 中注册,并确保 `storages` 配置示例被更新(`config.example.toml` / 文档)。
## 解析器插件JS
- **插件运行时**
- 解析器接口和插件文档在 `plugins/README.md`Go 端入口为 `parsers/` 目录,使用 `goja``playwright-go`
- 插件通过 `registerParser({ metadata, canHandle, parse })` 注册,使用 `ghttp`/`playwright` 进行 HTTP/浏览器请求。
- **与核心交互约定**
- 插件 `parse` 返回的 `Item`/`Resource` 会被转化为内部任务(通常是下载/转存任务)并进入 `core` 队列。
- 修改 `Item`/`Resource` 结构或解析逻辑时,要确保保持向后兼容,或在 `plugins/README.md` 中同步更新字段说明和示例。
## i18n 与日志
- **国际化**
- 所有用户可见字符串(尤其是错误与提示)应使用 `common/i18n``i18n.T(i18nk.SomeKey, map[string]any{"Name": name})`
- 语言文件位于 `common/i18n/locale/``go:generate` 指令在 `main.go` 中生成 `i18nk/keys.go`;新增文案时需:添加到 YAML、运行 `go generate ./...`、再在代码中引用新 key。
- **日志**
- 使用 `github.com/charmbracelet/log`,在 `cmd/run.go::Run` 中通过 `log.WithContext` 将 logger 注入 `context.Context`;后续代码优先通过 `log.FromContext(ctx)` 获取 logger。
- 编写新代码时,如已有 `ctx`,请使用 `log.FromContext(ctx)` 而不是全局 logger。
## 开发与运行
- **本地运行**
- 直接运行:`go run ./cmd``cmd/root.go` + `cmd/run.go`)。
- 或通过 Docker参见根目录 `README.md` 中的 `docker run ...` 示例及 `docker-compose.yml`
- **代码生成与文档**
- i18n key 生成:`go generate ./...` 会执行 `main.go` 顶部的 `//go:generate`,使用 `cmd/geni18n/main.go` 生成 `common/i18n/i18nk/keys.go`
- 文档站点在 `docs/`Hugo通常不需要在核心代码改动时同步修改除非涉及文档内容。
---
如果你在实现某个功能时发现以上规则不够具体(例如某类任务在 `core/tasks` 中到底如何落地,或某个存储/解析器的边界不清晰),请在对应章节下扩展更精确的说明,并在 PR 描述中标注。也欢迎你告诉我有哪些部分需要补充或澄清,我可以进一步细化。

4
.gitignore vendored
View File

@@ -9,4 +9,6 @@ cache.db
temp/
.hugo_build.lock
playwright/
testplugins/
testplugins/
*.exe
tmp-*

301
AGENTS.md Normal file
View File

@@ -0,0 +1,301 @@
# SaveAny-Bot Agent Guidelines
This document provides essential information for AI coding agents working on the SaveAny-Bot project.
## Project Overview
SaveAny-Bot is a Telegram bot written in Go that saves files/messages from Telegram and various websites to multiple storage backends (local, S3, MinIO, WebDAV, AList, Telegram). It features a plugin system for parsing web content and extensible storage backends.
**Tech Stack**: Go 1.24.2, gotd/td (Telegram MTProto), Cobra (CLI), Viper (config), GORM (ORM), SQLite, Goja (JS runtime), Playwright (browser automation)
## Build & Test Commands
### Build
```bash
# Standard build
go build -o saveany-bot .
# Run directly
go run ./cmd
# Docker build (multi-stage, Alpine-based)
docker build -t saveany-bot .
docker compose up -d
```
### Test
```bash
# Run all tests
go test ./...
# Run tests in specific package
go test ./pkg/queue
go test ./storage/telegram
# Run tests with verbose output
go test -v ./...
# Run a single test
go test -run TestQueueBasic ./pkg/queue
# Run with coverage
go test -cover ./...
```
### Lint & Format
```bash
# Format code (standard Go formatting)
go fmt ./...
# Vet code for common issues
go vet ./...
# Generate code (i18n keys)
go generate ./...
```
### Other Commands
```bash
# Update dependencies
go mod tidy
# View documentation
cd docs && hugo server -D
```
## Code Style Guidelines
### Imports
- Standard library first, then third-party, then project-internal
- Group imports with blank lines between groups
- Use explicit import aliases for clarity when needed (e.g., `storconfig`, `storenum`)
```go
import (
"context"
"fmt"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
```
### Formatting
- Line length: reasonable (no hard limit, but be sensible)
- Organize code with blank lines between logical sections
- Follow standard Go conventions for braces, spacing, etc.
### Types & Interfaces
- Use clear, descriptive type names (PascalCase for exported, camelCase for unexported)
- Define interfaces where abstraction is needed (e.g., `Executable`, `StorageConfig`)
- Embed context in method signatures, not structs: `func (s *Service) Do(ctx context.Context) error`
- Prefer composition over inheritance
```go
// Interfaces define behavior
type Executable interface {
Type() tasktype.TaskType
Title() string
TaskID() string
Execute(ctx context.Context) error
}
// Structs compose behavior
type Local struct {
config config.LocalStorageConfig
logger *log.Logger
}
```
### Naming Conventions
- **Packages**: lowercase, single word when possible (avoid underscores)
- **Files**: lowercase with underscores for multiword (e.g., `auth_terminal.go`, `progress_reader.go`)
- **Variables**: camelCase for unexported, PascalCase for exported
- **Constants**: PascalCase for exported, camelCase for unexported (not ALL_CAPS)
- **Functions/Methods**: PascalCase for exported, camelCase for unexported
- **Test files**: `*_test.go` pattern
### Error Handling
- Always handle errors explicitly; never ignore them
- Wrap errors with context using `fmt.Errorf("context: %w", err)`
- Use `errors.Is()` and `errors.As()` for error checking
- Log errors with appropriate level (Error, Warn, Info)
- Return errors from functions rather than panicking (except for truly unrecoverable situations)
```go
// Good error handling
if err := db.Save(user).Error; err != nil {
return fmt.Errorf("failed to save user %d: %w", user.ChatID, err)
}
// Check specific errors
if errors.Is(err, context.Canceled) {
logger.Info("Operation was canceled")
return nil
}
```
### Logging
- Use `github.com/charmbracelet/log` package
- Get logger from context: `log.FromContext(ctx)`
- Create prefixed loggers for components: `logger.WithPrefix("component")`
- Use appropriate levels: Debug, Info, Warn, Error
- Include context in log messages (e.g., task IDs, file names)
```go
logger := log.FromContext(ctx)
logger.Infof("Processing task: %s", task.ID)
logger.Errorf("Failed to save file %s: %v", filename, err)
```
### Concurrency
- Use channels for communication between goroutines
- Protect shared state with `sync.Mutex` or `sync.RWMutex`
- Use `sync.WaitGroup` for coordinating goroutine completion
- Always pass `context.Context` for cancellation support
- Use `context.WithCancel/WithTimeout` for managing goroutine lifetimes
```go
// Example from queue implementation
func (tq *TaskQueue[T]) Add(task *Task[T]) error {
tq.mu.Lock()
defer tq.mu.Unlock()
// ... critical section
tq.cond.Signal()
return nil
}
```
### Comments
- Document exported types, functions, and packages with doc comments
- Start doc comments with the name being documented
- Use `//` for single-line comments
- Explain *why*, not *what* (code should be self-explanatory for "what")
- Add `[NOTE]`, `[WARN]`, `[IMPORTANT]` tags for important clarifications
```go
// GetUserByChatID retrieves a user by their Telegram chat ID.
// Returns an error if the user is not found.
func GetUserByChatID(ctx context.Context, chatID int64) (*User, error) {
```
## Architecture & Conventions
### Application Structure
- **Entry point**: `main.go``cmd.Execute(ctx)`
- **CLI root**: `cmd/root.go` (Cobra), implementation in `cmd/run.go`
- **Startup sequence**: Config → Cache → i18n → Database → Storage → Parsers → Userbot → Bot → Queue
- Follow this order when adding new initialization steps in `cmd/run.go::initAll`
### Configuration (Viper)
- Config defined in `config/viper.go::Config`
- Read from `config.toml` (see `config.example.toml`)
- Environment variables: `SAVEANY_*` prefix (e.g., `SAVEANY_TELEGRAM_TOKEN`)
- Access via `config.C()` (returns a copy, don't modify the return value)
- Storage configs validated via `config/storage/factory.go::LoadStorageConfigs`
### Telegram Client
- **Bot client**: `client/bot/bot.go::Init` (uses gotgproto)
- **Handlers**: Centralized in `client/bot/handlers/` directory
- **Registration**: All handlers registered in `handlers.Register`
- **Commands**: Add to `CommandHandlers` slice for automatic `/help` and bot command list updates
- **Middleware**: Common middleware in `client/middleware/` (floodwait, retry, etc.)
### Tasks & Queue
- **Task interface**: `core/core.go::Executable` (Type, Title, TaskID, Execute methods)
- **Queue**: `pkg/queue.TaskQueue[Executable]` (generic, thread-safe)
- **Workers**: Count from `config.C().Workers`
- **Task types**: Implementations in `core/tasks/**` (tfile, parsed, telegraph, directlinks, batchtfile)
- **Lifecycle hooks**: `TaskBeforeStart`, `TaskSuccess`, `TaskFail`, `TaskCancel` (defined in config)
- **Adding tasks**: Use `core.AddTask(ctx, task)`
### Database (GORM + SQLite)
- **Init**: `database.Init` using `config.C().DB.Path`
- **Models**: User, Dir, Rule, WatchChat (in `database/*.go`)
- **Migrations**: Automatic via `db.AutoMigrate`
- **User sync**: `database.syncUsers` syncs DB with `config.C().Users` (don't manually create/delete users)
- **Context**: Always use `db.WithContext(ctx)` for operations
### Storage Backends
- **Interface**: Defined in `config/storage/types.go` and `storage/`
- **Implementations**: local, alist, s3/minio, webdav, telegram (each in subdirectory)
- **Adding new storage**:
1. Add enum to `pkg/enums/storage`
2. Create config struct in `config/storage/` with `Validate()` method
3. Implement storage in `storage/<name>/`
4. Register in `storageFactories` mapping
5. Update `config.example.toml` with example
### Parser Plugins (JavaScript)
- **Runtime**: Goja (JS runtime) + Playwright (browser automation)
- **Plugin API**: `registerParser({ metadata, canHandle, parse })` in JS
- **Integration**: Defined in `parsers/` directory
- **Documentation**: See `plugins/README.md`
- Plugin `parse` returns `Item`/`Resource` which becomes download/transfer task
### Internationalization (i18n)
- **Usage**: `i18n.T(i18nk.SomeKey, map[string]any{"Name": value})`
- **Locale files**: `common/i18n/locale/*.yaml`
- **Key generation**: Run `go generate ./...` to generate `common/i18n/i18nk/keys.go`
- **Adding new strings**: Add to YAML → run `go generate` → use in code
- All user-facing strings should be internationalized
### Context Usage
- Always pass `context.Context` as first parameter
- Use `log.FromContext(ctx)` to get contextual logger
- Respect context cancellation in long-running operations
- Store request-scoped data in context (e.g., `ctxkey.ContentLength`)
## Special Rules from .github/copilot-instructions.md
1. **Never modify `config.C()` return values** - it returns a copy. Modify config in `config.Init` or via Viper.
2. **Handlers must update `CommandHandlers` slice** - ensures `/help` and bot commands stay in sync.
3. **Task execution must preserve hooks** - don't remove `TaskBeforeStart`, `TaskSuccess`, `TaskFail`, `TaskCancel` hook calls.
4. **User sync is automatic** - don't manually create/delete users in DB; use config-based sync.
5. **Prefer context logger** - use `log.FromContext(ctx)` over global logger when context is available.
6. **Storage factory pattern** - new storage types must register in `storageFactories` mapping.
7. **Plugin API compatibility** - changes to `Item`/`Resource` structures require updating `plugins/README.md`.
## Common Patterns
### Adding a New Command
1. Create handler function in `client/bot/handlers/<name>.go`
2. Add to `CommandHandlers` slice in `register.go`
3. Add i18n key to `common/i18n/locale/*.yaml`
4. Run `go generate ./...`
5. Test with Telegram bot
### Adding a New Task Type
1. Create struct implementing `core.Executable` in `core/tasks/<type>/`
2. Implement `Type()`, `Title()`, `TaskID()`, `Execute(ctx)` methods
3. Add task type enum to `pkg/enums/tasktype`
4. Use `core.AddTask(ctx, task)` to enqueue
### Adding a New Storage Backend
1. Define config struct in `config/storage/<name>.go` with `Validate()` method
2. Implement storage interface in `storage/<name>/<name>.go`
3. Add storage type enum to `pkg/enums/storage`
4. Register factory in `config/storage/factory.go::storageFactories`
5. Update `config.example.toml` with configuration example
## File References
When referencing code locations, use `path/to/file.go:line` format (e.g., `core/core.go:23` for the worker function).
## Testing Guidelines
- Write tests for new functionality (place in `*_test.go` files)
- Test files should be in same package as code being tested
- Use table-driven tests for multiple test cases
- Mock external dependencies (databases, network calls)
- Aim for meaningful tests, not just coverage numbers
## Notes
- Binary size matters: use `CGO_ENABLED=0` for static binaries
- FFmpeg is included in Docker images for media processing
- Build process supports cross-compilation (amd64/arm64, Linux/macOS/Windows)
- Documentation site uses Hugo; edit files in `docs/` directory
- Session data stored in SQLite; delete `data/session.db` if changing bot token

View File

@@ -90,6 +90,17 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID)
case tasktype.TaskTypeDirectlinks:
shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID)
case tasktype.TaskTypeAria2:
client := GetAria2Client()
if client == nil {
ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
"Error": "aria2 client not initialized",
})))
return dispatcher.EndGroups
}
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
case tasktype.TaskTypeYtdlp:
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID)
default:
return fmt.Errorf("unexcept task type: %s", data.TaskType)
}

View File

@@ -58,6 +58,11 @@ var aria2ClientInitOnce sync.Once
var aria2ClientInitErr error
var aria2Client *aria2.Client
// GetAria2Client returns the shared aria2 client instance
func GetAria2Client() *aria2.Client {
return aria2Client
}
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
if !config.C().Aria2.Enable {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
@@ -78,7 +83,9 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
return nil
}
logger.Debug("Adding aria2 download", "links", links)
logger.Debug("Preparing aria2 download", "links", links)
// Initialize aria2 client to check connection
aria2ClientInitOnce.Do(func() {
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
})
@@ -89,17 +96,18 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
})), nil)
return nil
}
gid, err := aria2Client.AddURI(ctx, links, nil)
// Build storage selection keyboard (don't add to aria2 yet)
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
TaskType: tasktype.TaskTypeAria2,
Aria2URIs: links,
})
if err != nil {
logger.Error("Failed to add aria2 download", "error", err)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
"Error": err.Error(),
})), nil)
return nil
return err
}
logger.Info("Aria2 download added", "gid", gid)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{
"GID": gid,
})), nil)
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoSelectStorage)), &ext.ReplyOpts{
Markup: markup,
})
return nil
}

View File

@@ -30,6 +30,7 @@ var CommandHandlers = []DescCommandHandler{
{"save", i18nk.BotMsgCmdSave, handleSilentMode(handleSaveCmd, handleSilentSaveReplied)},
{"dl", i18nk.BotMsgCmdDl, handleDlCmd},
{"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd},
{"ytdlp", i18nk.BotMsgCmdYtdlp, handleYtdlpCmd},
{"task", i18nk.BotMsgCmdTask, handleTaskCmd},
{"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd},
{"config", i18nk.BotMsgCmdConfig, handleConfigCmd},

View File

@@ -26,7 +26,7 @@ func BuildDirHelpStyling(dirs []database.Dir) []styling.StyledTextOption {
styling.Blockquote(func() string {
var sb strings.Builder
for _, dir := range dirs {
sb.WriteString(fmt.Sprintf("%d: ", dir.ID))
fmt.Fprintf(&sb, "%d: ", dir.ID)
sb.WriteString(dir.StorageName)
sb.WriteString(" - ")
sb.WriteString(dir.Path)

View File

@@ -49,6 +49,9 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
ParsedItem: adddata.ParsedItem,
DirectLinks: adddata.DirectLinks,
Aria2URIs: adddata.Aria2URIs,
YtdlpURLs: adddata.YtdlpURLs,
}
dataid := xid.New().String()
err := cache.Set(dataid, data)

View File

@@ -0,0 +1,65 @@
package shortcut
import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
func CreateAndAddAria2TaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, uris []string, aria2Client *aria2.Client, msgID int, userID int64) error {
logger := log.FromContext(ctx)
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
// Now add to aria2 after user selected storage
logger.Infof("Adding download to aria2, uris type: %T, value: %+v", uris, uris)
// Ensure uris is valid
if len(uris) == 0 {
logger.Error("URIs list is empty")
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgDlErrorNoValidLinks, nil),
})
return dispatcher.EndGroups
}
gid, err := aria2Client.AddURI(ctx, uris, nil)
if err != nil {
logger.Errorf("Failed to add aria2 download: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
logger.Infof("Aria2 download added with GID: %s", gid)
// Create task with the GID
task := aria2dl.NewTask(xid.New().String(), injectCtx, gid, uris, aria2Client, stor, stor.JoinStoragePath(dirPath), aria2dl.NewProgress(msgID, userID))
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
})
return dispatcher.EndGroups
}

View File

@@ -0,0 +1,62 @@
package shortcut
import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/rs/xid"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
"github.com/krau/SaveAny-Bot/storage"
)
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error {
logger := log.FromContext(ctx)
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
// Validate URLs
if len(urls) == 0 {
logger.Error("URLs list is empty")
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls, nil),
})
return dispatcher.EndGroups
}
logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls))
// Create yt-dlp task
task := ytdlp.NewTask(
xid.New().String(),
injectCtx,
urls,
stor,
stor.JoinStoragePath(dirPath),
ytdlp.NewProgress(msgID, userID),
)
// Add task to queue
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add yt-dlp task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
})
return dispatcher.EndGroups
}

View File

@@ -5,7 +5,9 @@ import (
"path"
"regexp"
"strings"
"sync"
"text/template"
"time"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
@@ -16,10 +18,12 @@ import (
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/tfile"
coretfile "github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/enums/fnamest"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
@@ -151,6 +155,53 @@ func handleUnwatchCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
type watchMediaGroupHandler struct {
groups map[int64]map[uint][]tfile.TGFileMessage // chatID -> userID -> files
timers map[int64]map[uint]*time.Timer
mu sync.Mutex
}
var watchMediaGroupMgr = &watchMediaGroupHandler{
groups: make(map[int64]map[uint][]tfile.TGFileMessage),
timers: make(map[int64]map[uint]*time.Timer),
}
func (w *watchMediaGroupHandler) addFile(chatID int64, userID uint, file tfile.TGFileMessage, timeout time.Duration, callback func([]tfile.TGFileMessage)) {
w.mu.Lock()
defer w.mu.Unlock()
if w.groups[chatID] == nil {
w.groups[chatID] = make(map[uint][]tfile.TGFileMessage)
}
if w.timers[chatID] == nil {
w.timers[chatID] = make(map[uint]*time.Timer)
}
if timer, exists := w.timers[chatID][userID]; exists {
timer.Stop()
}
w.groups[chatID][userID] = append(w.groups[chatID][userID], file)
w.timers[chatID][userID] = time.AfterFunc(timeout, func() {
w.mu.Lock()
files := w.groups[chatID][userID]
delete(w.groups[chatID], userID)
delete(w.timers[chatID], userID)
if len(w.groups[chatID]) == 0 {
delete(w.groups, chatID)
}
if len(w.timers[chatID]) == 0 {
delete(w.timers, chatID)
}
w.mu.Unlock()
if len(files) > 0 {
callback(files)
}
})
}
func listenMediaMessageEvent(ch chan userclient.MediaMessageEvent) {
if userclient.GetCtx() == nil {
return
@@ -224,6 +275,24 @@ func listenMediaMessageEvent(ch chan userclient.MediaMessageEvent) {
}
file.SetName(sb.String())
}
// Check if this is a media group and if rules specify NEW-FOR-ALBUM
groupID, isGroup := file.Message().GetGroupedID()
needAlbumHandling := false
if isGroup && groupID != 0 && user.ApplyRule && user.Rules != nil {
_, _, matchedDirPath := ruleutil.ApplyRule(ctx, user.Rules, ruleutil.NewInput(file))
needAlbumHandling = matchedDirPath.NeedNewForAlbum()
}
if needAlbumHandling {
// For media groups with NEW-FOR-ALBUM rule, collect all files of the same group
watchMediaGroupMgr.addFile(event.ChatID, user.ID, file, time.Duration(config.C().Telegram.MediaGroupTimeout)*time.Second, func(files []tfile.TGFileMessage) {
processWatchMediaGroup(ctx, user, stor, "", files)
})
continue
}
// Process single file or media group without album folder creation
var dirPath string
if user.ApplyRule && user.Rules != nil {
matched, matchedStorageName, matchedDirPath := ruleutil.ApplyRule(ctx, user.Rules, ruleutil.NewInput(file))
@@ -243,7 +312,7 @@ func listenMediaMessageEvent(ch chan userclient.MediaMessageEvent) {
storagePath := stor.JoinStoragePath(path.Join(dirPath, file.Name()))
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
taskid := xid.New().String()
task, err := tfile.NewTGFileTask(taskid, injectCtx, file, stor, storagePath, nil)
task, err := coretfile.NewTGFileTask(taskid, injectCtx, file, stor, storagePath, nil)
if err != nil {
logger.Errorf("create task failed: %s", err)
continue
@@ -256,3 +325,97 @@ func listenMediaMessageEvent(ch chan userclient.MediaMessageEvent) {
}
}
}
func processWatchMediaGroup(ctx *ext.Context, user *database.User, stor storage.Storage, dirPath string, files []tfile.TGFileMessage) {
logger := log.FromContext(ctx)
if len(files) == 0 {
return
}
useRule := user.ApplyRule && user.Rules != nil
applyRule := func(file tfile.TGFileMessage) (string, ruleutil.MatchedDirPath) {
if !useRule {
return stor.Name(), ruleutil.MatchedDirPath(dirPath)
}
matched, storName, dirP := ruleutil.ApplyRule(ctx, user.Rules, ruleutil.NewInput(file))
if !matched {
return stor.Name(), ruleutil.MatchedDirPath(dirPath)
}
storname := storName.String()
if !storName.Usable() {
storname = stor.Name()
}
return storname, dirP
}
type albumFile struct {
file tfile.TGFileMessage
storage storage.Storage
}
albumFiles := make(map[int64][]albumFile)
// Collect files by group ID
for _, file := range files {
storName, ruleDirPath := applyRule(file)
fileStor := stor
if storName != stor.Name() && storName != "" {
var err error
fileStor, err = storage.GetStorageByUserIDAndName(ctx, user.ChatID, storName)
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
continue
}
}
groupId, isGroup := file.Message().GetGroupedID()
if !isGroup || groupId == 0 {
logger.Warnf("File %s is not in a group, skipping", file.Name())
continue
}
if !ruleDirPath.NeedNewForAlbum() {
logger.Warnf("File %s does not need album folder, skipping", file.Name())
continue
}
if _, ok := albumFiles[groupId]; !ok {
albumFiles[groupId] = make([]albumFile, 0)
}
albumFiles[groupId] = append(albumFiles[groupId], albumFile{
file: file,
storage: fileStor,
})
}
// Process album files with folder creation
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
totalTasks := 0
for groupID, afiles := range albumFiles {
if len(afiles) <= 1 {
continue
}
// Use first file's name (without extension) as album folder name
albumDir := strings.TrimSuffix(path.Base(afiles[0].file.Name()), path.Ext(afiles[0].file.Name()))
albumStor := afiles[0].storage
logger.Infof("Creating album folder for group %d: %s with %d files", groupID, albumDir, len(afiles))
for _, af := range afiles {
afstorPath := af.storage.JoinStoragePath(path.Join(dirPath, albumDir, af.file.Name()))
taskid := xid.New().String()
task, err := coretfile.NewTGFileTask(taskid, injectCtx, af.file, albumStor, afstorPath, nil)
if err != nil {
logger.Errorf("create task failed for album file: %s", err)
continue
}
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("add task failed: %s", err)
continue
}
totalTasks++
}
}
logger.Infof("Added %d watch media tasks for user %d", totalTasks, user.ChatID)
}

View File

@@ -0,0 +1,63 @@
package handlers
import (
"net/url"
"strings"
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/slice"
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
"github.com/krau/SaveAny-Bot/storage"
)
func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error {
logger := log.FromContext(ctx)
args := strings.Split(update.EffectiveMessage.Text, " ")
if len(args) < 2 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpUsage)), nil)
return dispatcher.EndGroups
}
urls := args[1:]
// Validate and clean URLs
for i, link := range urls {
urls[i] = strings.TrimSpace(link)
u, err := url.Parse(link)
if err != nil || u.Scheme == "" || u.Host == "" {
logger.Warnf("Invalid URL: %s", link)
urls[i] = ""
}
}
urls = slice.Compact(urls)
if len(urls) == 0 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil)
return dispatcher.EndGroups
}
logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls))
// Build storage selection keyboard
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
TaskType: tasktype.TaskTypeYtdlp,
YtdlpURLs: urls,
})
if err != nil {
return err
}
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpInfoUrlsSelectStorage, map[string]any{
"Count": len(urls),
})), &ext.ReplyOpts{
Markup: markup,
})
return dispatcher.EndGroups
}

View File

@@ -9,6 +9,7 @@ const (
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
BotMsgAria2InfoSelectStorage Key = "bot.msg.aria2.info_select_storage"
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
@@ -32,6 +33,7 @@ const (
BotMsgCmdUnwatch Key = "bot.msg.cmd.unwatch"
BotMsgCmdUpdate Key = "bot.msg.cmd.update"
BotMsgCmdWatch Key = "bot.msg.cmd.watch"
BotMsgCmdYtdlp Key = "bot.msg.cmd.ytdlp"
BotMsgCommonCancelButtonText Key = "bot.msg.common.cancel_button_text"
BotMsgCommonErrorBuildDirSelectKeyboardFailed Key = "bot.msg.common.error_build_dir_select_keyboard_failed"
BotMsgCommonErrorBuildStorageSelectKeyboardFailed Key = "bot.msg.common.error_build_storage_select_keyboard_failed"
@@ -127,15 +129,20 @@ const (
BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success"
BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled"
BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file"
BotMsgProgressAria2Done Key = "bot.msg.progress.aria2_done"
BotMsgProgressAria2Downloading Key = "bot.msg.progress.aria2_downloading"
BotMsgProgressAria2Start Key = "bot.msg.progress.aria2_start"
BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix"
BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix"
BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix"
BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix"
BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix"
BotMsgProgressCurrentSpeedPrefix Key = "bot.msg.progress.current_speed_prefix"
BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix"
BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start"
BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix"
BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix"
BotMsgProgressDownloadedPrefix Key = "bot.msg.progress.downloaded_prefix"
BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix"
BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix"
BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix"
@@ -154,6 +161,9 @@ const (
BotMsgProgressTelegraphProgressPrefix Key = "bot.msg.progress.telegraph_progress_prefix"
BotMsgProgressTelegraphStartPrefix Key = "bot.msg.progress.telegraph_start_prefix"
BotMsgProgressTotalSizePrefix Key = "bot.msg.progress.total_size_prefix"
BotMsgProgressYtdlpDone Key = "bot.msg.progress.ytdlp_done"
BotMsgProgressYtdlpDownloading Key = "bot.msg.progress.ytdlp_downloading"
BotMsgProgressYtdlpStart Key = "bot.msg.progress.ytdlp_start"
BotMsgRuleErrorCreateRuleFailed Key = "bot.msg.rule.error_create_rule_failed"
BotMsgRuleErrorDeleteRuleFailed Key = "bot.msg.rule.error_delete_rule_failed"
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
@@ -229,6 +239,11 @@ const (
BotMsgWatchInfoWatchListFilterPrefix Key = "bot.msg.watch.info_watch_list_filter_prefix"
BotMsgWatchInfoWatchListHeader Key = "bot.msg.watch.info_watch_list_header"
BotMsgWatchHelpText Key = "bot.msg.watch_help_text"
BotMsgYtdlpErrorDownloadFailed Key = "bot.msg.ytdlp.error_download_failed"
BotMsgYtdlpErrorNoValidUrls Key = "bot.msg.ytdlp.error_no_valid_urls"
BotMsgYtdlpInfoDownloading Key = "bot.msg.ytdlp.info_downloading"
BotMsgYtdlpInfoUrlsSelectStorage Key = "bot.msg.ytdlp.info_urls_select_storage"
BotMsgYtdlpUsage Key = "bot.msg.ytdlp.usage"
ConfigErrDuplicateStorageName Key = "config.err.duplicate_storage_name"
ConfigErrInvalidCacheDir Key = "config.err.invalid_cache_dir"
ErrCleanCacheFailed Key = "err.clean_cache_failed"

View File

@@ -50,6 +50,8 @@ bot:
rule: "Manage auto-save rules"
save: "Save files"
dl: "Download files from given links"
aria2dl: "Download files using Aria2"
ytdlp: "Download video/audio using yt-dlp"
task: "Manage task queue"
cancel: "Cancel task"
watch: "Watch chats (UserBot)"
@@ -286,6 +288,12 @@ bot:
usage: "Usage: /dl <url1> <url2> ..."
error_no_valid_links: "No valid links to download"
info_files_select_storage: "Total {{.Count}} files, please select storage"
ytdlp:
usage: "Usage: /ytdlp <URL1> <URL2> ..."
error_no_valid_urls: "No valid URLs"
info_urls_select_storage: "Found {{.Count}} links, please select storage"
info_downloading: "Downloading via yt-dlp..."
error_download_failed: "yt-dlp download failed: {{.Error}}"
cancel:
usage: "Usage: /cancel <task_id>"
error_cancel_failed: "Failed to cancel task: {{.Error}}"
@@ -326,7 +334,22 @@ bot:
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
file_name_prefix: "Filename: "
error_prefix: "\nError: "
aria2_start: "Waiting for Aria2 to complete download (GID: {{.GID}})..."
aria2_downloading: "Aria2 downloading (GID: {{.GID}})\n"
aria2_done: "Aria2 download completed and transferred (GID: {{.GID}})\n"
ytdlp_start: "Starting yt-dlp download ({{.Count}} links)..."
ytdlp_downloading: "yt-dlp downloading ({{.Count}} links)\n"
ytdlp_done: "yt-dlp download completed and transferred ({{.Count}} files)\n"
downloaded_prefix: "\nDownloaded: "
current_speed_prefix: "\nCurrent speed: "
syncpeers:
start: "Starting to sync peers..."
done: "Peer sync completed, total {{.Count}} chats synced"
failed: "Peer sync failed: {{.Error}}"
aria2:
error_aria2_not_enabled: "Aria2 feature is not enabled in the configuration"
error_aria2_client_init_failed: "Aria2 client initialization failed: {{.Error}}"
info_adding_aria2_download: "Adding Aria2 download task..."
error_adding_aria2_download: "Failed to add Aria2 download task: {{.Error}}"
info_aria2_download_added: "Aria2 download task added, GID: {{.GID}}"
info_select_storage: "Please select storage, the task will be added to Aria2 download queue after selection"

View File

@@ -52,6 +52,7 @@ bot:
save: "保存文件"
dl: "下载给定链接的文件"
aria2dl: "使用 Aria2 下载给定链接的文件"
ytdlp: "使用 yt-dlp 下载视频/音频"
task: "管理任务队列"
cancel: "取消任务"
watch: "监听聊天(UserBot)"
@@ -288,6 +289,12 @@ bot:
usage: "用法: /dl <链接1> <链接2> ..."
error_no_valid_links: "没有有效的链接可供下载"
info_files_select_storage: "共 {{.Count}} 个文件, 请选择存储位置"
ytdlp:
usage: "用法: /ytdlp <URL1> <URL2> ..."
error_no_valid_urls: "没有有效的 URL"
info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置"
info_downloading: "正在通过 yt-dlp 下载..."
error_download_failed: "yt-dlp 下载失败: {{.Error}}"
cancel:
usage: "用法: /cancel <task_id>"
error_cancel_failed: "取消任务失败: {{.Error}}"
@@ -328,6 +335,14 @@ bot:
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
file_name_prefix: "文件名: "
error_prefix: "\n错误: "
aria2_start: "等待 Aria2 下载完成 (GID: {{.GID}})..."
aria2_downloading: "Aria2 正在下载 (GID: {{.GID}})\n"
aria2_done: "Aria2 下载完成并已转存 (GID: {{.GID}})\n"
ytdlp_start: "开始使用 yt-dlp 下载 ({{.Count}} 个链接)..."
ytdlp_downloading: "yt-dlp 正在下载 ({{.Count}} 个链接)\n"
ytdlp_done: "yt-dlp 下载完成并已转存 ({{.Count}} 个文件)\n"
downloaded_prefix: "\n已下载: "
current_speed_prefix: "\n当前速度: "
syncpeers:
start: "正在同步对话列表..."
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
@@ -338,3 +353,4 @@ bot:
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列"

View File

@@ -18,6 +18,17 @@ token = ""
enable = false
url = "socks5://127.0.0.1:7890"
# Aria2 配置
[aria2]
# 启用 Aria2 下载支持
enable = false
# Aria2 RPC URL
url = "http://localhost:6800/jsonrpc"
# Aria2 RPC Secret (如果配置了 rpc-secret)
secret = ""
# 转存完成后删除 Aria2 下载的本地文件
remove_after_transfer = true
# 存储列表
[[storages]]
# 标识名, 需要唯一

View File

@@ -36,9 +36,10 @@ type Config struct {
}
type aria2Config struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Url string `toml:"url" mapstructure:"url" json:"url"`
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Url string `toml:"url" mapstructure:"url" json:"url"`
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
KeepFile bool `toml:"keep_file" mapstructure:"keep_file" json:"keep_file"`
}
var cfg = &Config{}

View File

@@ -0,0 +1,250 @@
package aria2dl
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
)
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID)
if t.Progress != nil {
t.Progress.OnStart(ctx, t)
}
// Wait for aria2 download to complete
if err := t.waitForDownload(ctx); err != nil {
// If context was canceled, also cancel the aria2 download
if errors.Is(err, context.Canceled) {
t.cancelAria2Download()
}
logger.Errorf("Aria2 download failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
// Transfer downloaded files to storage
if err := t.transferFiles(ctx); err != nil {
logger.Errorf("File transfer failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
logger.Infof("Aria2 task %s completed successfully", t.ID)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, nil)
}
// Clean up aria2 download result
if _, err := t.Aria2Client.RemoveDownloadResult(context.Background(), t.GID); err != nil {
logger.Warnf("Failed to remove aria2 download result: %v", err)
}
return nil
}
// waitForDownload waits for aria2 to complete the download
func (t *Task) waitForDownload(ctx context.Context) error {
logger := log.FromContext(ctx)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
status, err := t.getStatus(ctx)
if err != nil {
return err
}
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, status)
}
// Check if download is complete
if status.IsDownloadComplete() {
// Handle metadata downloads (torrent/magnet) that spawn follow-up downloads
if len(status.FollowedBy) > 0 {
logger.Infof("Switching from metadata GID %s to actual download GID: %s", t.GID, status.FollowedBy[0])
t.GID = status.FollowedBy[0]
continue
}
logger.Infof("Download completed for GID %s", t.GID)
return nil
}
// Check for errors
if status.IsDownloadError() {
return fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode)
}
if status.IsDownloadRemoved() {
return errors.New("aria2 download was removed")
}
}
}
}
// getStatus retrieves the current status of the download
func (t *Task) getStatus(ctx context.Context) (*aria2.Status, error) {
logger := log.FromContext(ctx)
// Try active/waiting queue first
status, err := t.Aria2Client.TellStatus(ctx, t.GID)
if err == nil {
return status, nil
}
// Check stopped queue
logger.Debugf("Task not in active queue, checking stopped queue")
stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100)
if stopErr != nil {
return nil, fmt.Errorf("failed to get aria2 status: %w", err)
}
for _, task := range stoppedTasks {
if task.GID == t.GID {
logger.Debugf("Found task in stopped queue with status: %s", task.Status)
return &task, nil
}
}
return nil, fmt.Errorf("task GID %s not found: %w", t.GID, err)
}
// transferFiles transfers downloaded files from aria2 to storage
func (t *Task) transferFiles(ctx context.Context) error {
logger := log.FromContext(ctx)
status, err := t.Aria2Client.TellStatus(ctx, t.GID)
if err != nil {
return fmt.Errorf("failed to get final status: %w", err)
}
if len(status.Files) == 0 {
return errors.New("no files in aria2 download")
}
logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name())
transferredCount := 0
for _, file := range status.Files {
if file.Selected != "true" {
logger.Debugf("Skipping unselected file: %s", file.Path)
continue
}
fileName := filepath.Base(file.Path)
// Skip torrent metadata files
if filepath.Ext(fileName) == ".torrent" {
logger.Debugf("Skipping torrent metadata file: %s", fileName)
t.removeFileIfNeeded(file.Path)
continue
}
if err := t.transferFile(ctx, file.Path); err != nil {
return err
}
transferredCount++
t.removeFileIfNeeded(file.Path)
}
if transferredCount == 0 {
return errors.New("no files were transferred")
}
return nil
}
// transferFile transfers a single file to storage
func (t *Task) transferFile(ctx context.Context, filePath string) error {
logger := log.FromContext(ctx)
// Check if file exists
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
logger.Warnf("Downloaded file not found: %s", filePath)
return nil // Not a fatal error, continue with other files
}
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Open file
f, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
defer f.Close()
// Set content length in context for storage
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
// Save to storage
fileName := filepath.Base(filePath)
destPath := filepath.Join(t.StorPath, fileName)
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
if err := t.Storage.Save(ctx, f, destPath); err != nil {
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
}
logger.Infof("Successfully transferred file %s", fileName)
return nil
}
// removeFileIfNeeded removes a file if RemoveAfterTransfer is enabled
func (t *Task) removeFileIfNeeded(filePath string) {
if config.C().Aria2.KeepFile {
return
}
logger := log.FromContext(t.ctx)
if err := os.Remove(filePath); err != nil {
logger.Warnf("Failed to remove local file %s: %v", filePath, err)
} else {
logger.Debugf("Removed local file %s", filePath)
}
}
// cancelAria2Download cancels the aria2 download task
func (t *Task) cancelAria2Download() {
logger := log.FromContext(t.ctx)
logger.Infof("Canceling aria2 download GID: %s", t.GID)
// Use a background context with timeout for cleanup
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Try to force remove the download
if _, err := t.Aria2Client.ForceRemove(ctx, t.GID); err != nil {
logger.Warnf("Failed to cancel aria2 download %s: %v", t.GID, err)
} else {
logger.Infof("Successfully canceled aria2 download %s", t.GID)
}
// Also remove the download result to clean up
if _, err := t.Aria2Client.RemoveDownloadResult(ctx, t.GID); err != nil {
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
}
}

View File

@@ -0,0 +1,189 @@
package aria2dl
import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/pkg/aria2"
)
type ProgressTracker interface {
OnStart(ctx context.Context, task *Task)
OnProgress(ctx context.Context, task *Task, status *aria2.Status)
OnDone(ctx context.Context, task *Task, err error)
}
type Progress struct {
msgID int
chatID int64
start time.Time
lastUpdatePercent atomic.Int32
}
// OnStart implements ProgressTracker.
func (p *Progress) OnStart(ctx context.Context, task *Task) {
logger := log.FromContext(ctx)
p.start = time.Now()
p.lastUpdatePercent.Store(0)
logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID)
ext := tgutil.ExtFromContext(ctx)
if ext == nil {
return
}
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{
"GID": task.GID,
}))); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext.EditMessage(p.chatID, req)
}
// OnProgress implements ProgressTracker.
func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
if totalLength == 0 {
return
}
percent := int((completedLength * 100) / totalLength)
if p.lastUpdatePercent.Load() == int32(percent) {
return
}
p.lastUpdatePercent.Store(int32(percent))
log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength)
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{
"GID": task.GID,
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)),
styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))),
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)),
styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
// OnDone implements ProgressTracker.
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
logger := log.FromContext(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
logger.Infof("Aria2 task %s was canceled", task.TaskID())
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
"TaskID": task.TaskID(),
}),
})
}
} else {
logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
"Error": err.Error(),
}),
})
}
}
return
}
logger.Infof("Aria2 task %s completed successfully", task.TaskID())
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{
"GID": task.GID,
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
logger.Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
var _ ProgressTracker = (*Progress)(nil)
func NewProgress(msgID int, userID int64) ProgressTracker {
return &Progress{
msgID: msgID,
chatID: userID,
}
}

View File

@@ -0,0 +1,61 @@
package aria2dl
import (
"context"
"fmt"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
var _ core.Executable = (*Task)(nil)
type Task struct {
ID string
ctx context.Context
GID string
URIs []string
Aria2Client *aria2.Client
Storage storage.Storage
StorPath string
Progress ProgressTracker
}
// Title implements core.Executable.
func (t *Task) Title() string {
return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath)
}
// Type implements core.Executable.
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeAria2
}
// TaskID implements core.Executable.
func (t *Task) TaskID() string {
return t.ID
}
func NewTask(
id string,
ctx context.Context,
gid string,
uris []string,
aria2Client *aria2.Client,
stor storage.Storage,
storPath string,
progressTracker ProgressTracker,
) *Task {
return &Task{
ID: id,
ctx: ctx,
GID: gid,
URIs: uris,
Aria2Client: aria2Client,
Storage: stor,
StorPath: storPath,
Progress: progressTracker,
}
}

View File

@@ -0,0 +1,209 @@
package aria2dl
import (
"context"
"io"
"testing"
"time"
storconfig "github.com/krau/SaveAny-Bot/config/storage"
"github.com/krau/SaveAny-Bot/pkg/aria2"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
type mockStorage struct {
name string
savePath string
}
func (m *mockStorage) Name() string {
return m.name
}
func (m *mockStorage) Type() storenum.StorageType {
return storenum.StorageType("mock")
}
func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error {
return nil
}
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
m.savePath = path
return nil
}
func (m *mockStorage) Exists(ctx context.Context, path string) bool {
return false
}
func (m *mockStorage) JoinStoragePath(path string) string {
return path
}
type mockProgress struct {
started bool
done bool
doneErr error
progress int
}
func (m *mockProgress) OnStart(ctx context.Context, task *Task) {
m.started = true
}
func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
m.progress++
}
func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) {
m.done = true
m.doneErr = err
}
func TestTaskCreation(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
mockProg,
)
if task.ID != "test-task-id" {
t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID)
}
if task.GID != "test-gid" {
t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID)
}
if task.Type() != tasktype.TaskTypeAria2 {
t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type())
}
if task.TaskID() != "test-task-id" {
t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID())
}
if task.Storage.Name() != "test-storage" {
t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name())
}
}
func TestProgressTracker(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
mockProg,
)
// Test OnStart
mockProg.OnStart(ctx, task)
if !mockProg.started {
t.Error("Expected OnStart to set started to true")
}
// Test OnProgress
status := &aria2.Status{
GID: "test-gid",
Status: "active",
TotalLength: "1000000",
CompletedLength: "500000",
DownloadSpeed: "100000",
}
mockProg.OnProgress(ctx, task, status)
if mockProg.progress != 1 {
t.Errorf("Expected progress to be 1, got %d", mockProg.progress)
}
// Test OnDone
mockProg.OnDone(ctx, task, nil)
if !mockProg.done {
t.Error("Expected OnDone to set done to true")
}
if mockProg.doneErr != nil {
t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr)
}
}
func TestTaskTitle(t *testing.T) {
ctx := context.Background()
mockStor := &mockStorage{name: "test-storage"}
task := NewTask(
"test-task-id",
ctx,
"test-gid-123",
[]string{"http://example.com/file.zip"},
nil,
mockStor,
"/test/path",
nil,
)
title := task.Title()
expectedSubstr := "test-gid-123"
if len(title) == 0 {
t.Error("Expected title to not be empty")
}
// Check if title contains the GID
found := false
for i := 0; i < len(title)-len(expectedSubstr)+1; i++ {
if title[i:i+len(expectedSubstr)] == expectedSubstr {
found = true
break
}
}
if !found {
t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title)
}
}
func TestContextCancellation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mockStor := &mockStorage{name: "test-storage"}
mockProg := &mockProgress{}
task := NewTask(
"test-task-id",
ctx,
"test-gid",
[]string{"http://example.com/file.zip"},
nil, // nil client will cause Execute to fail/timeout
mockStor,
"/test/path",
mockProg,
)
// Just verify the task structure is valid
if task.ctx.Err() != nil {
t.Error("Context should not be cancelled yet")
}
// Wait for context to timeout
<-ctx.Done()
if ctx.Err() == nil {
t.Error("Context should be cancelled after timeout")
}
}

182
core/tasks/ytdlp/execute.go Normal file
View File

@@ -0,0 +1,182 @@
package ytdlp
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
ytdlp "github.com/lrstanley/go-ytdlp"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
)
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Infof("Starting yt-dlp download task %s", t.ID)
if t.Progress != nil {
t.Progress.OnStart(ctx, t)
}
// Create temporary directory for downloads
tempDir, err := os.MkdirTemp(config.C().Temp.BasePath, "ytdlp-*")
if err != nil {
logger.Errorf("Failed to create temp directory: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tempDir) // Clean up temp directory
logger.Debugf("Created temp directory: %s", tempDir)
// Download files using yt-dlp
downloadedFiles, err := t.downloadFiles(ctx, tempDir)
if err != nil {
logger.Errorf("yt-dlp download failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
if len(downloadedFiles) == 0 {
err := errors.New("no files were downloaded")
logger.Error(err.Error())
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
// Transfer downloaded files to storage
logger.Infof("Transferring %d file(s) to storage %s", len(downloadedFiles), t.Storage.Name())
for _, filePath := range downloadedFiles {
if err := t.transferFile(ctx, filePath); err != nil {
logger.Errorf("File transfer failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
}
logger.Infof("yt-dlp task %s completed successfully", t.ID)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, nil)
}
return nil
}
// downloadFiles downloads files using yt-dlp and returns the list of downloaded file paths
func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) {
logger := log.FromContext(ctx)
// Configure yt-dlp command
cmd := ytdlp.New().
FormatSort("res,ext:mp4:m4a").
RecodeVideo("mp4").
Output(filepath.Join(tempDir, "%(title)s.%(ext)s")).
RestrictFilenames()
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, "Downloading...")
}
// Execute download with URLs as arguments
logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs))
// Run with context for cancellation support
result, err := cmd.Run(ctx, t.URLs...)
if err != nil {
// Check if context was canceled
if errors.Is(err, context.Canceled) {
return nil, err
}
return nil, fmt.Errorf("yt-dlp execution failed: %w", err)
}
if result.ExitCode != 0 {
return nil, fmt.Errorf("yt-dlp exited with code %d: %s", result.ExitCode, result.Stderr)
}
// List downloaded files
files, err := os.ReadDir(tempDir)
if err != nil {
return nil, fmt.Errorf("failed to read temp directory: %w", err)
}
var downloadedFiles []string
for _, file := range files {
if file.IsDir() {
continue
}
fullPath := filepath.Join(tempDir, file.Name())
downloadedFiles = append(downloadedFiles, fullPath)
logger.Debugf("Downloaded file: %s", file.Name())
}
return downloadedFiles, nil
}
// transferFile transfers a single file to storage
func (t *Task) transferFile(ctx context.Context, filePath string) error {
logger := log.FromContext(ctx)
// Check if file exists
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
logger.Warnf("Downloaded file not found: %s", filePath)
return nil // Not a fatal error
}
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Open file
f, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
defer f.Close()
// Set content length in context for storage
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
// Save to storage
fileName := filepath.Base(filePath)
// Remove special characters from filename if needed
fileName = sanitizeFilename(fileName)
destPath := filepath.Join(t.StorPath, fileName)
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
if err := t.Storage.Save(ctx, f, destPath); err != nil {
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
}
logger.Infof("Successfully transferred file %s", fileName)
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, fmt.Sprintf("Transferred: %s", fileName))
}
return nil
}
// sanitizeFilename removes or replaces problematic characters in filenames
func sanitizeFilename(name string) string {
// yt-dlp with --restrict-filenames should already handle most cases
// but we can do additional sanitization if needed
name = strings.ReplaceAll(name, ":", "_")
name = strings.ReplaceAll(name, "\"", "'")
return name
}

View File

@@ -0,0 +1,183 @@
package ytdlp
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
// ProgressTracker defines the interface for tracking ytdlp task progress
type ProgressTracker interface {
OnStart(ctx context.Context, task *Task)
OnProgress(ctx context.Context, task *Task, status string)
OnDone(ctx context.Context, task *Task, err error)
}
type Progress struct {
msgID int
chatID int64
start time.Time
lastUpdate atomic.Value // stores time.Time
minUpdateInterval time.Duration
}
// OnStart implements ProgressTracker.
func (p *Progress) OnStart(ctx context.Context, task *Task) {
logger := log.FromContext(ctx)
p.start = time.Now()
p.lastUpdate.Store(time.Now())
p.minUpdateInterval = 2 * time.Second // Avoid too frequent updates
logger.Infof("yt-dlp task started: message_id=%d, chat_id=%d, urls=%d", p.msgID, p.chatID, len(task.URLs))
ext := tgutil.ExtFromContext(ctx)
if ext == nil {
return
}
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpStart, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext.EditMessage(p.chatID, req)
}
// OnProgress implements ProgressTracker.
func (p *Progress) OnProgress(ctx context.Context, task *Task, status string) {
// Throttle updates to avoid flooding Telegram API
lastUpdateTime := p.lastUpdate.Load().(time.Time)
if time.Since(lastUpdateTime) < p.minUpdateInterval {
return
}
p.lastUpdate.Store(time.Now())
log.FromContext(ctx).Debugf("yt-dlp progress update: %s", status)
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDownloading, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
styling.Plain("\n\n"),
styling.Plain(status),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
// OnDone implements ProgressTracker.
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
logger := log.FromContext(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
logger.Infof("yt-dlp task %s was canceled", task.TaskID())
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
"TaskID": task.TaskID(),
}),
})
}
} else {
logger.Errorf("yt-dlp task %s failed: %s", task.TaskID(), err)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
"Error": err.Error(),
}),
})
}
}
return
}
logger.Infof("yt-dlp task %s completed successfully", task.TaskID())
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDone, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
logger.Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
var _ ProgressTracker = (*Progress)(nil)
func NewProgress(msgID int, userID int64) ProgressTracker {
return &Progress{
msgID: msgID,
chatID: userID,
minUpdateInterval: 2 * time.Second,
}
}

58
core/tasks/ytdlp/task.go Normal file
View File

@@ -0,0 +1,58 @@
package ytdlp
import (
"context"
"fmt"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
var _ core.Executable = (*Task)(nil)
type Task struct {
ID string
ctx context.Context
URLs []string
Storage storage.Storage
StorPath string
Progress ProgressTracker
}
// Title implements core.Executable.
func (t *Task) Title() string {
urlCount := len(t.URLs)
if urlCount == 1 {
return fmt.Sprintf("[%s](%s->%s:%s)", t.Type(), t.URLs[0], t.Storage.Name(), t.StorPath)
}
return fmt.Sprintf("[%s](%d URLs->%s:%s)", t.Type(), urlCount, t.Storage.Name(), t.StorPath)
}
// Type implements core.Executable.
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeYtdlp
}
// TaskID implements core.Executable.
func (t *Task) TaskID() string {
return t.ID
}
func NewTask(
id string,
ctx context.Context,
urls []string,
stor storage.Storage,
storPath string,
progressTracker ProgressTracker,
) *Task {
return &Task{
ID: id,
ctx: ctx,
URLs: urls,
Storage: stor,
StorPath: storPath,
Progress: progressTracker,
}
}

3
go.mod
View File

@@ -17,6 +17,7 @@ require (
github.com/gotd/td v0.137.0
github.com/johannesboyne/gofakes3 v0.0.0-20250916175020-ebf3e50324d3
github.com/krau/ffmpeg-go v0.6.0
github.com/lrstanley/go-ytdlp v1.2.7
github.com/minio/minio-go/v7 v7.0.98
github.com/playwright-community/playwright-go v0.5200.1
github.com/rs/xid v1.6.0
@@ -31,6 +32,7 @@ require (
require (
github.com/AnimeKaizoku/cacher v1.0.3 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aws/smithy-go v1.24.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -42,6 +44,7 @@ require (
github.com/clipperhouse/displaywidth v0.7.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/deckarep/golang-set/v2 v2.8.0 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect

6
go.sum
View File

@@ -4,6 +4,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
@@ -66,6 +68,8 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
@@ -176,6 +180,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/krau/ffmpeg-go v0.6.0 h1:F4HWvOrKXQsfLsFTOnUfP0HY6WISJqOrsAFGSIzkKto=
github.com/krau/ffmpeg-go v0.6.0/go.mod h1:sa7/bWHB6fO9j4lhmxnWQ1U07o+dE1leFjhctotxU7A=
github.com/lrstanley/go-ytdlp v1.2.7 h1:YNDvKkd0OCJSZLZePZvJwcirBCfL8Yw3eCwrTCE5w7Q=
github.com/lrstanley/go-ytdlp v1.2.7/go.mod h1:38IL64XM6gULrWtKTiR0+TTNCVbxesNSbTyaFG2CGTI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=

View File

@@ -1,5 +1,5 @@
package tasktype
//go:generate go-enum --values --names --flag --nocase
// ENUM(tgfiles,tphpics,parseditem,directlinks)
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp)
type TaskType string

View File

@@ -20,6 +20,10 @@ const (
TaskTypeParseditem TaskType = "parseditem"
// TaskTypeDirectlinks is a TaskType of type directlinks.
TaskTypeDirectlinks TaskType = "directlinks"
// TaskTypeAria2 is a TaskType of type aria2.
TaskTypeAria2 TaskType = "aria2"
// TaskTypeYtdlp is a TaskType of type ytdlp.
TaskTypeYtdlp TaskType = "ytdlp"
)
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
@@ -29,6 +33,8 @@ var _TaskTypeNames = []string{
string(TaskTypeTphpics),
string(TaskTypeParseditem),
string(TaskTypeDirectlinks),
string(TaskTypeAria2),
string(TaskTypeYtdlp),
}
// TaskTypeNames returns a list of possible string values of TaskType.
@@ -45,6 +51,8 @@ func TaskTypeValues() []TaskType {
TaskTypeTphpics,
TaskTypeParseditem,
TaskTypeDirectlinks,
TaskTypeAria2,
TaskTypeYtdlp,
}
}
@@ -65,6 +73,8 @@ var _TaskTypeValue = map[string]TaskType{
"tphpics": TaskTypeTphpics,
"parseditem": TaskTypeParseditem,
"directlinks": TaskTypeDirectlinks,
"aria2": TaskTypeAria2,
"ytdlp": TaskTypeYtdlp,
}
// ParseTaskType attempts to convert a string to a TaskType.

View File

@@ -45,6 +45,10 @@ type Add struct {
ParsedItem *parser.Item
// directlinks
DirectLinks []string
// aria2
Aria2URIs []string
// ytdlp
YtdlpURLs []string
}
type SetDefaultStorage struct {