feat: add yt-dlp support for downloading video/audio and enhance related commands

This commit is contained in:
krau
2026-01-17 17:42:11 +08:00
parent cd7cf4964d
commit 3ce00884a0
16 changed files with 602 additions and 4 deletions

182
core/tasks/ytdlp/execute.go Normal file
View File

@@ -0,0 +1,182 @@
package ytdlp
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
ytdlp "github.com/lrstanley/go-ytdlp"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
)
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Infof("Starting yt-dlp download task %s", t.ID)
if t.Progress != nil {
t.Progress.OnStart(ctx, t)
}
// Create temporary directory for downloads
tempDir, err := os.MkdirTemp(config.C().Temp.BasePath, "ytdlp-*")
if err != nil {
logger.Errorf("Failed to create temp directory: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tempDir) // Clean up temp directory
logger.Debugf("Created temp directory: %s", tempDir)
// Download files using yt-dlp
downloadedFiles, err := t.downloadFiles(ctx, tempDir)
if err != nil {
logger.Errorf("yt-dlp download failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
if len(downloadedFiles) == 0 {
err := errors.New("no files were downloaded")
logger.Error(err.Error())
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
// Transfer downloaded files to storage
logger.Infof("Transferring %d file(s) to storage %s", len(downloadedFiles), t.Storage.Name())
for _, filePath := range downloadedFiles {
if err := t.transferFile(ctx, filePath); err != nil {
logger.Errorf("File transfer failed: %v", err)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, err)
}
return err
}
}
logger.Infof("yt-dlp task %s completed successfully", t.ID)
if t.Progress != nil {
t.Progress.OnDone(ctx, t, nil)
}
return nil
}
// downloadFiles downloads files using yt-dlp and returns the list of downloaded file paths
func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) {
logger := log.FromContext(ctx)
// Configure yt-dlp command
cmd := ytdlp.New().
FormatSort("res,ext:mp4:m4a").
RecodeVideo("mp4").
Output(filepath.Join(tempDir, "%(title)s.%(ext)s")).
RestrictFilenames()
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, "Downloading...")
}
// Execute download with URLs as arguments
logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs))
// Run with context for cancellation support
result, err := cmd.Run(ctx, t.URLs...)
if err != nil {
// Check if context was canceled
if errors.Is(err, context.Canceled) {
return nil, err
}
return nil, fmt.Errorf("yt-dlp execution failed: %w", err)
}
if result.ExitCode != 0 {
return nil, fmt.Errorf("yt-dlp exited with code %d: %s", result.ExitCode, result.Stderr)
}
// List downloaded files
files, err := os.ReadDir(tempDir)
if err != nil {
return nil, fmt.Errorf("failed to read temp directory: %w", err)
}
var downloadedFiles []string
for _, file := range files {
if file.IsDir() {
continue
}
fullPath := filepath.Join(tempDir, file.Name())
downloadedFiles = append(downloadedFiles, fullPath)
logger.Debugf("Downloaded file: %s", file.Name())
}
return downloadedFiles, nil
}
// transferFile transfers a single file to storage
func (t *Task) transferFile(ctx context.Context, filePath string) error {
logger := log.FromContext(ctx)
// Check if file exists
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
logger.Warnf("Downloaded file not found: %s", filePath)
return nil // Not a fatal error
}
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
}
// Open file
f, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
defer f.Close()
// Set content length in context for storage
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
// Save to storage
fileName := filepath.Base(filePath)
// Remove special characters from filename if needed
fileName = sanitizeFilename(fileName)
destPath := filepath.Join(t.StorPath, fileName)
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
if err := t.Storage.Save(ctx, f, destPath); err != nil {
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
}
logger.Infof("Successfully transferred file %s", fileName)
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, fmt.Sprintf("Transferred: %s", fileName))
}
return nil
}
// sanitizeFilename removes or replaces problematic characters in filenames
func sanitizeFilename(name string) string {
// yt-dlp with --restrict-filenames should already handle most cases
// but we can do additional sanitization if needed
name = strings.ReplaceAll(name, ":", "_")
name = strings.ReplaceAll(name, "\"", "'")
return name
}

View File

@@ -0,0 +1,183 @@
package ytdlp
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
)
// ProgressTracker defines the interface for tracking ytdlp task progress
type ProgressTracker interface {
OnStart(ctx context.Context, task *Task)
OnProgress(ctx context.Context, task *Task, status string)
OnDone(ctx context.Context, task *Task, err error)
}
type Progress struct {
msgID int
chatID int64
start time.Time
lastUpdate atomic.Value // stores time.Time
minUpdateInterval time.Duration
}
// OnStart implements ProgressTracker.
func (p *Progress) OnStart(ctx context.Context, task *Task) {
logger := log.FromContext(ctx)
p.start = time.Now()
p.lastUpdate.Store(time.Now())
p.minUpdateInterval = 2 * time.Second // Avoid too frequent updates
logger.Infof("yt-dlp task started: message_id=%d, chat_id=%d, urls=%d", p.msgID, p.chatID, len(task.URLs))
ext := tgutil.ExtFromContext(ctx)
if ext == nil {
return
}
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpStart, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext.EditMessage(p.chatID, req)
}
// OnProgress implements ProgressTracker.
func (p *Progress) OnProgress(ctx context.Context, task *Task, status string) {
// Throttle updates to avoid flooding Telegram API
lastUpdateTime := p.lastUpdate.Load().(time.Time)
if time.Since(lastUpdateTime) < p.minUpdateInterval {
return
}
p.lastUpdate.Store(time.Now())
log.FromContext(ctx).Debugf("yt-dlp progress update: %s", status)
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDownloading, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
styling.Plain("\n\n"),
styling.Plain(status),
); err != nil {
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
Rows: []tg.KeyboardButtonRow{
{
Buttons: []tg.KeyboardButtonClass{
tgutil.BuildCancelButton(task.TaskID()),
},
},
}},
)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
// OnDone implements ProgressTracker.
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
logger := log.FromContext(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
logger.Infof("yt-dlp task %s was canceled", task.TaskID())
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
"TaskID": task.TaskID(),
}),
})
}
} else {
logger.Errorf("yt-dlp task %s failed: %s", task.TaskID(), err)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
ID: p.msgID,
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
"Error": err.Error(),
}),
})
}
}
return
}
logger.Infof("yt-dlp task %s completed successfully", task.TaskID())
entityBuilder := entity.Builder{}
if err := styling.Perform(&entityBuilder,
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDone, map[string]any{
"Count": len(task.URLs),
})),
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
); err != nil {
logger.Errorf("Failed to build entities: %s", err)
return
}
text, entities := entityBuilder.Complete()
req := &tg.MessagesEditMessageRequest{
ID: p.msgID,
}
req.SetMessage(text)
req.SetEntities(entities)
ext := tgutil.ExtFromContext(ctx)
if ext != nil {
ext.EditMessage(p.chatID, req)
}
}
var _ ProgressTracker = (*Progress)(nil)
func NewProgress(msgID int, userID int64) ProgressTracker {
return &Progress{
msgID: msgID,
chatID: userID,
minUpdateInterval: 2 * time.Second,
}
}

58
core/tasks/ytdlp/task.go Normal file
View File

@@ -0,0 +1,58 @@
package ytdlp
import (
"context"
"fmt"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
var _ core.Executable = (*Task)(nil)
type Task struct {
ID string
ctx context.Context
URLs []string
Storage storage.Storage
StorPath string
Progress ProgressTracker
}
// Title implements core.Executable.
func (t *Task) Title() string {
urlCount := len(t.URLs)
if urlCount == 1 {
return fmt.Sprintf("[%s](%s->%s:%s)", t.Type(), t.URLs[0], t.Storage.Name(), t.StorPath)
}
return fmt.Sprintf("[%s](%d URLs->%s:%s)", t.Type(), urlCount, t.Storage.Name(), t.StorPath)
}
// Type implements core.Executable.
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeYtdlp
}
// TaskID implements core.Executable.
func (t *Task) TaskID() string {
return t.ID
}
func NewTask(
id string,
ctx context.Context,
urls []string,
stor storage.Storage,
storPath string,
progressTracker ProgressTracker,
) *Task {
return &Task{
ID: id,
ctx: ctx,
URLs: urls,
Storage: stor,
StorPath: storPath,
Progress: progressTracker,
}
}