mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-25 01:03:43 +08:00
Compare commits
18 Commits
v0.48.4
...
copilot/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef1ecf960f | ||
|
|
e1b4087801 | ||
|
|
6896bdc852 | ||
|
|
3a6402a71b | ||
|
|
173a5e3733 | ||
|
|
20a5e317ae | ||
|
|
127901fd24 | ||
|
|
30c165033e | ||
|
|
9dcb5201e1 | ||
|
|
7b0142ef82 | ||
|
|
2f6b2470a4 | ||
|
|
ac10c32215 | ||
|
|
7def7f5b28 | ||
|
|
3ce00884a0 | ||
|
|
cd7cf4964d | ||
|
|
bc3c841d1d | ||
|
|
743c15f1a5 | ||
|
|
b05d86509c |
629
api/handlers.go
Normal file
629
api/handlers.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/client/bot"
|
||||
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/ruleutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile"
|
||||
"github.com/krau/SaveAny-Bot/database"
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// Request/Response types
|
||||
type CreateTaskRequest struct {
|
||||
TelegramURL string `json:"telegram_url"`
|
||||
StorageName string `json:"storage_name,omitempty"`
|
||||
DirPath string `json:"dir_path,omitempty"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type CreateTaskResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TaskStatusResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"` // queued, running, completed, failed, canceled
|
||||
Title string `json:"title"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Downloaded int64 `json:"downloaded,omitempty"` // Bytes downloaded
|
||||
Total int64 `json:"total,omitempty"` // Total bytes
|
||||
ProgressPct float64 `json:"progress_pct,omitempty"` // Progress percentage (0-100)
|
||||
}
|
||||
|
||||
type ListTasksResponse struct {
|
||||
Queued []TaskInfo `json:"queued"`
|
||||
Running []TaskInfo `json:"running"`
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Task tracking
|
||||
var (
|
||||
taskStatuses = make(map[string]*taskStatus)
|
||||
taskStatusesMu sync.RWMutex
|
||||
)
|
||||
|
||||
type taskStatus struct {
|
||||
ID string
|
||||
Status string
|
||||
Title string
|
||||
CreatedAt time.Time
|
||||
Error string
|
||||
Downloaded atomic.Int64 // Use atomic for lock-free updates
|
||||
Total atomic.Int64 // Use atomic for lock-free updates
|
||||
ProgressPct uint64 // Store as uint64 bits of float64 for atomic access
|
||||
}
|
||||
|
||||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func handleCreateTask(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateTaskRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
respondError(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.TelegramURL == "" {
|
||||
respondError(w, "telegram_url is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.UserID <= 0 {
|
||||
respondError(w, "user_id is required and must be positive", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.FromContext(r.Context()).WithPrefix("api")
|
||||
|
||||
// Get user from database
|
||||
userDB, err := database.GetUserByChatID(r.Context(), req.UserID)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get user: %v", err)
|
||||
respondError(w, "user not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get storage
|
||||
var stor storage.Storage
|
||||
if req.StorageName != "" {
|
||||
stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get storage: %v", err)
|
||||
respondError(w, "storage not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Use first available storage for the user
|
||||
storages := storage.GetUserStorages(r.Context(), req.UserID)
|
||||
if len(storages) == 0 {
|
||||
respondError(w, "no storage available for user", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
stor = storages[0]
|
||||
}
|
||||
|
||||
// Parse Telegram URL
|
||||
botCtx := bot.ExtContext()
|
||||
if botCtx == nil {
|
||||
respondError(w, "bot not initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
linkUrl, err := url.Parse(req.TelegramURL)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to parse URL: %v", err)
|
||||
respondError(w, "invalid telegram URL format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to parse Telegram URL: %v", err)
|
||||
respondError(w, "invalid telegram URL format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get message from Telegram
|
||||
msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get message: %v", err)
|
||||
respondError(w, "failed to retrieve message", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if message has media
|
||||
media, ok := msg.GetMedia()
|
||||
if !ok {
|
||||
respondError(w, "message has no media", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Collect files - handle both single and grouped messages
|
||||
files := make([]tfile.TGFileMessage, 0)
|
||||
|
||||
// Check for grouped messages (media group)
|
||||
groupID, isGroup := msg.GetGroupedID()
|
||||
if isGroup && groupID != 0 && !linkUrl.Query().Has("single") {
|
||||
// Handle media group
|
||||
gmsgs, err := tgutil.GetGroupedMessages(botCtx, chatID, msg)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get grouped messages: %v", err)
|
||||
// Fall back to single message
|
||||
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create TGFile: %v", err)
|
||||
respondError(w, "invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
files = append(files, file)
|
||||
} else {
|
||||
// Process all messages in the group
|
||||
for _, gmsg := range gmsgs {
|
||||
if gmsg.Media == nil {
|
||||
continue
|
||||
}
|
||||
gMedia, ok := gmsg.GetMedia()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
file, err := createTGFileWithMedia(botCtx, gmsg, gMedia, userDB)
|
||||
if err != nil {
|
||||
logger.Warnf("Failed to create TGFile for grouped message: %v", err)
|
||||
continue
|
||||
}
|
||||
files = append(files, file)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single message
|
||||
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create TGFile: %v", err)
|
||||
respondError(w, "invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
respondError(w, "no savable files found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create tasks for all files with proper album handling
|
||||
taskIDs := make([]string, 0, len(files))
|
||||
baseDirPath := req.DirPath
|
||||
if baseDirPath == "" {
|
||||
baseDirPath = "/"
|
||||
}
|
||||
|
||||
// Create context with bot extension
|
||||
injectCtx := tgutil.ExtWithContext(r.Context(), botCtx)
|
||||
|
||||
// Apply storage rules if enabled for the user
|
||||
useRule := userDB.ApplyRule && userDB.Rules != nil
|
||||
|
||||
// Helper to apply rules to a file
|
||||
applyRule := func(file tfile.TGFileMessage) (storage.Storage, ruleutil.MatchedDirPath) {
|
||||
fileStor := stor
|
||||
dirPath := ruleutil.MatchedDirPath(baseDirPath)
|
||||
|
||||
if useRule {
|
||||
matched, matchedStorName, matchedDirPath := ruleutil.ApplyRule(injectCtx, userDB.Rules, ruleutil.NewInput(file))
|
||||
if matched {
|
||||
// Rule matched, apply overrides
|
||||
if matchedDirPath != "" {
|
||||
dirPath = matchedDirPath
|
||||
}
|
||||
if matchedStorName.Usable() {
|
||||
var err error
|
||||
fileStor, err = storage.GetStorageByUserIDAndName(injectCtx, userDB.ChatID, matchedStorName.String())
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to get storage from rule: %v", err)
|
||||
// Fall back to original storage
|
||||
fileStor = stor
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return fileStor, dirPath
|
||||
}
|
||||
|
||||
// Separate files into regular and album files
|
||||
type albumFile struct {
|
||||
file tfile.TGFileMessage
|
||||
storage storage.Storage
|
||||
dirPath ruleutil.MatchedDirPath
|
||||
}
|
||||
albumFiles := make(map[int64][]albumFile)
|
||||
|
||||
for _, tgFile := range files {
|
||||
fileStor, dirPath := applyRule(tgFile)
|
||||
|
||||
// Check if this needs album handling (NEW-FOR-ALBUM rule)
|
||||
if dirPath.NeedNewForAlbum() {
|
||||
groupId, isGroup := tgFile.Message().GetGroupedID()
|
||||
if !isGroup || groupId == 0 {
|
||||
logger.Warnf("File %s has NEW-FOR-ALBUM rule but is not in a group, treating as regular file", tgFile.Name())
|
||||
// Treat as regular file with base dir path
|
||||
storagePath := fileStor.JoinStoragePath(path.Join(baseDirPath, tgFile.Name()))
|
||||
taskID := xid.New().String()
|
||||
|
||||
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{
|
||||
taskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create task: %v", err)
|
||||
respondError(w, "failed to create task", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
trackTask(taskID, task.Title(), "queued")
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add task: %v", err)
|
||||
updateTaskStatus(taskID, "failed", err.Error())
|
||||
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Group by album ID
|
||||
if _, ok := albumFiles[groupId]; !ok {
|
||||
albumFiles[groupId] = make([]albumFile, 0)
|
||||
}
|
||||
albumFiles[groupId] = append(albumFiles[groupId], albumFile{
|
||||
file: tgFile,
|
||||
storage: fileStor,
|
||||
dirPath: dirPath,
|
||||
})
|
||||
} else {
|
||||
// Regular file - create task immediately
|
||||
storagePath := fileStor.JoinStoragePath(path.Join(dirPath.String(), tgFile.Name()))
|
||||
taskID := xid.New().String()
|
||||
|
||||
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{
|
||||
taskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create task: %v", err)
|
||||
respondError(w, "failed to create task", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
trackTask(taskID, task.Title(), "queued")
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add task: %v", err)
|
||||
updateTaskStatus(taskID, "failed", err.Error())
|
||||
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle album files - group them into subdirectories
|
||||
for _, afiles := range albumFiles {
|
||||
if len(afiles) <= 1 {
|
||||
// Single file in album, treat as regular
|
||||
for _, af := range afiles {
|
||||
storagePath := af.storage.JoinStoragePath(path.Join(baseDirPath, af.file.Name()))
|
||||
taskID := xid.New().String()
|
||||
|
||||
task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, af.storage, storagePath, &apiProgressTracker{
|
||||
taskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create task: %v", err)
|
||||
respondError(w, "failed to create task", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
trackTask(taskID, task.Title(), "queued")
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add task: %v", err)
|
||||
updateTaskStatus(taskID, "failed", err.Error())
|
||||
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Multiple files in album - create subdirectory named after first file
|
||||
// Remove extension from first file's name to use as directory name
|
||||
albumDir := strings.TrimSuffix(path.Base(afiles[0].file.Name()), path.Ext(afiles[0].file.Name()))
|
||||
albumStor := afiles[0].storage
|
||||
|
||||
for _, af := range afiles {
|
||||
// All files go into the album subdirectory
|
||||
storagePath := albumStor.JoinStoragePath(path.Join(baseDirPath, albumDir, af.file.Name()))
|
||||
taskID := xid.New().String()
|
||||
|
||||
task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, albumStor, storagePath, &apiProgressTracker{
|
||||
taskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create task for album file: %v", err)
|
||||
respondError(w, "failed to create task", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
trackTask(taskID, task.Title(), "queued")
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add task: %v", err)
|
||||
updateTaskStatus(taskID, "failed", err.Error())
|
||||
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
taskIDs = append(taskIDs, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// Send success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
// Return first task ID for single file, or all task IDs for media group
|
||||
if len(taskIDs) == 1 {
|
||||
json.NewEncoder(w).Encode(CreateTaskResponse{
|
||||
TaskID: taskIDs[0],
|
||||
Message: "task created successfully",
|
||||
})
|
||||
} else {
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"task_ids": taskIDs,
|
||||
"message": fmt.Sprintf("%d tasks created successfully", len(taskIDs)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTGFileWithMedia creates a TGFile with proper filename handling using user's strategy
|
||||
func createTGFileWithMedia(botCtx *ext.Context, msg *tg.Message, media tg.MessageMediaClass, userDB *database.User) (tfile.TGFileMessage, error) {
|
||||
// Use the same filename generation logic as bot handlers
|
||||
opts := []tfile.TGFileOption{tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))}
|
||||
return tfile.FromMediaMessage(media, botCtx.Raw, msg, opts...)
|
||||
}
|
||||
|
||||
func handleGetTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := r.PathValue("id")
|
||||
if taskID == "" {
|
||||
respondError(w, "task_id is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
taskStatusesMu.RLock()
|
||||
status, exists := taskStatuses[taskID]
|
||||
taskStatusesMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
respondError(w, "task not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TaskStatusResponse{
|
||||
TaskID: status.ID,
|
||||
Status: status.Status,
|
||||
Title: status.Title,
|
||||
CreatedAt: status.CreatedAt,
|
||||
Error: status.Error,
|
||||
Downloaded: status.Downloaded.Load(),
|
||||
Total: status.Total.Load(),
|
||||
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&status.ProgressPct))),
|
||||
})
|
||||
}
|
||||
|
||||
func handleListTasks(w http.ResponseWriter, r *http.Request) {
|
||||
queued := core.GetQueuedTasks(r.Context())
|
||||
running := core.GetRunningTasks(r.Context())
|
||||
|
||||
response := ListTasksResponse{
|
||||
Queued: convertTaskInfos(queued),
|
||||
Running: convertTaskInfos(running),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func handleCancelTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := r.PathValue("id")
|
||||
if taskID == "" {
|
||||
respondError(w, "task_id is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
||||
log.FromContext(r.Context()).Errorf("Failed to cancel task %s: %v", taskID, err)
|
||||
respondError(w, "failed to cancel task", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
updateTaskStatus(taskID, "canceled", "")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "task canceled"})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func respondError(w http.ResponseWriter, message string, statusCode int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
json.NewEncoder(w).Encode(ErrorResponse{Error: message})
|
||||
}
|
||||
|
||||
func trackTask(taskID, title, status string) {
|
||||
taskStatusesMu.Lock()
|
||||
defer taskStatusesMu.Unlock()
|
||||
taskStatuses[taskID] = &taskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
Title: title,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func updateTaskStatus(taskID, status, errorMsg string) {
|
||||
taskStatusesMu.Lock()
|
||||
defer taskStatusesMu.Unlock()
|
||||
if ts, exists := taskStatuses[taskID]; exists {
|
||||
ts.Status = status
|
||||
ts.Error = errorMsg
|
||||
}
|
||||
}
|
||||
|
||||
func convertTaskInfos(tasks []queue.TaskInfo) []TaskInfo {
|
||||
result := make([]TaskInfo, len(tasks))
|
||||
for i, t := range tasks {
|
||||
result[i] = TaskInfo{
|
||||
ID: t.ID,
|
||||
Title: t.Title,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// apiProgressTracker implements tftask.ProgressTracker for API tasks
|
||||
type apiProgressTracker struct {
|
||||
taskID string
|
||||
}
|
||||
|
||||
func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo) {
|
||||
updateTaskStatus(a.taskID, "running", "")
|
||||
}
|
||||
|
||||
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
|
||||
// Use atomic operations to avoid mutex locks for better performance
|
||||
// OnProgress is called very frequently during downloads
|
||||
taskStatusesMu.RLock()
|
||||
ts, exists := taskStatuses[a.taskID]
|
||||
taskStatusesMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
ts.Downloaded.Store(downloaded)
|
||||
ts.Total.Store(total)
|
||||
if total > 0 {
|
||||
progressPct := float64(downloaded) / float64(total) * 100.0
|
||||
atomic.StoreUint64((*uint64)(&ts.ProgressPct), math.Float64bits(progressPct))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) {
|
||||
if err != nil {
|
||||
updateTaskStatus(a.taskID, "failed", err.Error())
|
||||
sendWebhook(a.taskID, "failed", err.Error())
|
||||
} else {
|
||||
updateTaskStatus(a.taskID, "completed", "")
|
||||
sendWebhook(a.taskID, "completed", "")
|
||||
}
|
||||
}
|
||||
|
||||
// sendWebhook sends a callback to the configured webhook URL
|
||||
func sendWebhook(taskID, status, errorMsg string) {
|
||||
cfg := config.C()
|
||||
if cfg.API.WebhookURL == "" {
|
||||
return
|
||||
}
|
||||
|
||||
taskStatusesMu.RLock()
|
||||
ts, exists := taskStatuses[taskID]
|
||||
taskStatusesMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithPrefix("webhook")
|
||||
|
||||
payload := TaskStatusResponse{
|
||||
TaskID: ts.ID,
|
||||
Status: status,
|
||||
Title: ts.Title,
|
||||
CreatedAt: ts.CreatedAt,
|
||||
Error: errorMsg,
|
||||
Downloaded: ts.Downloaded.Load(),
|
||||
Total: ts.Total.Load(),
|
||||
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&ts.ProgressPct))),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to marshal webhook payload: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", cfg.API.WebhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create webhook request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if cfg.API.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+cfg.API.Token)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to send webhook: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Errorf("Webhook returned error status %d, failed to read response body: %v", resp.StatusCode, err)
|
||||
} else {
|
||||
logger.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
}
|
||||
}
|
||||
66
api/middleware.go
Normal file
66
api/middleware.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
// authMiddleware validates API token
|
||||
func authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip auth for health check
|
||||
if r.URL.Path == "/health" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cfg := config.C()
|
||||
|
||||
// Check token if configured
|
||||
if cfg.API.Token != "" {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, `{"error":"unauthorized: missing token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token != cfg.API.Token {
|
||||
http.Error(w, `{"error":"unauthorized: invalid token"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// loggingMiddleware logs HTTP requests
|
||||
func loggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Wrap response writer to capture status code
|
||||
wrapper := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next.ServeHTTP(wrapper, r)
|
||||
|
||||
logger.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapper.statusCode, time.Since(start))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
47
api/middleware_test.go
Normal file
47
api/middleware_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware_NoAuth(t *testing.T) {
|
||||
// Create a test handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/api/v1/tasks", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Apply middleware
|
||||
authMiddleware(handler).ServeHTTP(rec, req)
|
||||
|
||||
// When no token is configured, request should succeed or be unauthorized
|
||||
if rec.Code != http.StatusOK && rec.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 200 or 401, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_HealthCheck(t *testing.T) {
|
||||
// Create a test handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
// Create request to health endpoint
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Apply middleware
|
||||
authMiddleware(handler).ServeHTTP(rec, req)
|
||||
|
||||
// Health check should always work without auth
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for health check, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
77
api/server.go
Normal file
77
api/server.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
var server *http.Server
|
||||
|
||||
// Init initializes and starts the HTTP API server
|
||||
func Init(ctx context.Context) error {
|
||||
cfg := config.C()
|
||||
if !cfg.API.Enable {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate that token is configured when API is enabled
|
||||
if cfg.API.Token == "" {
|
||||
return fmt.Errorf("API is enabled but token is not configured. Please set 'api.token' in your configuration file for security")
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx).WithPrefix("api")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register API routes
|
||||
registerRoutes(mux)
|
||||
|
||||
// Wrap with middleware
|
||||
handler := loggingMiddleware(logger)(authMiddleware(mux))
|
||||
|
||||
server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.API.Port),
|
||||
Handler: handler,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 15 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.Infof("Starting API server on port %d", cfg.API.Port)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Errorf("API server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Graceful shutdown on context cancellation
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Errorf("Failed to shutdown API server: %v", err)
|
||||
} else {
|
||||
logger.Info("API server stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerRoutes(mux *http.ServeMux) {
|
||||
// Health check endpoint (no auth required)
|
||||
mux.HandleFunc("/health", handleHealth)
|
||||
|
||||
// API v1 endpoints
|
||||
mux.HandleFunc("POST /api/v1/tasks", handleCreateTask)
|
||||
mux.HandleFunc("GET /api/v1/tasks/{id}", handleGetTask)
|
||||
mux.HandleFunc("GET /api/v1/tasks", handleListTasks)
|
||||
mux.HandleFunc("DELETE /api/v1/tasks/{id}", handleCancelTask)
|
||||
}
|
||||
@@ -90,6 +90,17 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
|
||||
shortcut.CreateAndAddParsedTaskWithEdit(ctx, selectedStorage, dirPath, data.ParsedItem, msgID, userID)
|
||||
case tasktype.TaskTypeDirectlinks:
|
||||
shortcut.CreateAndAddDirectTaskWithEdit(ctx, selectedStorage, dirPath, data.DirectLinks, msgID, userID)
|
||||
case tasktype.TaskTypeAria2:
|
||||
client := GetAria2Client()
|
||||
if client == nil {
|
||||
ctx.AnswerCallback(msgelem.AlertCallbackAnswer(queryID, i18n.T(i18nk.BotMsgAria2ErrorAria2ClientInitFailed, map[string]any{
|
||||
"Error": "aria2 client not initialized",
|
||||
})))
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
|
||||
case tasktype.TaskTypeYtdlp:
|
||||
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID)
|
||||
default:
|
||||
return fmt.Errorf("unexcept task type: %s", data.TaskType)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,11 @@ var aria2ClientInitOnce sync.Once
|
||||
var aria2ClientInitErr error
|
||||
var aria2Client *aria2.Client
|
||||
|
||||
// GetAria2Client returns the shared aria2 client instance
|
||||
func GetAria2Client() *aria2.Client {
|
||||
return aria2Client
|
||||
}
|
||||
|
||||
func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
if !config.C().Aria2.Enable {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAria2NotEnabled)), nil)
|
||||
@@ -78,7 +83,9 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgDlErrorNoValidLinks)), nil)
|
||||
return nil
|
||||
}
|
||||
logger.Debug("Adding aria2 download", "links", links)
|
||||
logger.Debug("Preparing aria2 download", "links", links)
|
||||
|
||||
// Initialize aria2 client to check connection
|
||||
aria2ClientInitOnce.Do(func() {
|
||||
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)
|
||||
})
|
||||
@@ -89,17 +96,18 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
})), nil)
|
||||
return nil
|
||||
}
|
||||
gid, err := aria2Client.AddURI(ctx, links, nil)
|
||||
|
||||
// Build storage selection keyboard (don't add to aria2 yet)
|
||||
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
|
||||
TaskType: tasktype.TaskTypeAria2,
|
||||
Aria2URIs: links,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to add aria2 download", "error", err)
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
|
||||
"Error": err.Error(),
|
||||
})), nil)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
logger.Info("Aria2 download added", "gid", gid)
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoAria2DownloadAdded, map[string]any{
|
||||
"GID": gid,
|
||||
})), nil)
|
||||
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgAria2InfoSelectStorage)), &ext.ReplyOpts{
|
||||
Markup: markup,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ var CommandHandlers = []DescCommandHandler{
|
||||
{"save", i18nk.BotMsgCmdSave, handleSilentMode(handleSaveCmd, handleSilentSaveReplied)},
|
||||
{"dl", i18nk.BotMsgCmdDl, handleDlCmd},
|
||||
{"aria2dl", i18nk.BotMsgCmdAria2dl, handleAria2DlCmd},
|
||||
{"ytdlp", i18nk.BotMsgCmdYtdlp, handleYtdlpCmd},
|
||||
{"task", i18nk.BotMsgCmdTask, handleTaskCmd},
|
||||
{"cancel", i18nk.BotMsgCmdCancel, handleCancelCmd},
|
||||
{"config", i18nk.BotMsgCmdConfig, handleConfigCmd},
|
||||
|
||||
@@ -49,6 +49,9 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
|
||||
ParsedItem: adddata.ParsedItem,
|
||||
|
||||
DirectLinks: adddata.DirectLinks,
|
||||
|
||||
Aria2URIs: adddata.Aria2URIs,
|
||||
YtdlpURLs: adddata.YtdlpURLs,
|
||||
}
|
||||
dataid := xid.New().String()
|
||||
err := cache.Set(dataid, data)
|
||||
|
||||
65
client/bot/handlers/utils/shortcut/aria2.go
Normal file
65
client/bot/handlers/utils/shortcut/aria2.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package shortcut
|
||||
|
||||
import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
func CreateAndAddAria2TaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, uris []string, aria2Client *aria2.Client, msgID int, userID int64) error {
|
||||
logger := log.FromContext(ctx)
|
||||
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
|
||||
|
||||
// Now add to aria2 after user selected storage
|
||||
logger.Infof("Adding download to aria2, uris type: %T, value: %+v", uris, uris)
|
||||
|
||||
// Ensure uris is valid
|
||||
if len(uris) == 0 {
|
||||
logger.Error("URIs list is empty")
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgDlErrorNoValidLinks, nil),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
gid, err := aria2Client.AddURI(ctx, uris, nil)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to add aria2 download: %s", err)
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgAria2ErrorAddingAria2Download, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
logger.Infof("Aria2 download added with GID: %s", gid)
|
||||
|
||||
// Create task with the GID
|
||||
task := aria2dl.NewTask(xid.New().String(), injectCtx, gid, uris, aria2Client, stor, stor.JoinStoragePath(dirPath), aria2dl.NewProgress(msgID, userID))
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add task: %s", err)
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
62
client/bot/handlers/utils/shortcut/ytdlp.go
Normal file
62
client/bot/handlers/utils/shortcut/ytdlp.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package shortcut
|
||||
|
||||
import (
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error {
|
||||
logger := log.FromContext(ctx)
|
||||
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
|
||||
|
||||
// Validate URLs
|
||||
if len(urls) == 0 {
|
||||
logger.Error("URLs list is empty")
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls, nil),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls))
|
||||
|
||||
// Create yt-dlp task
|
||||
task := ytdlp.NewTask(
|
||||
xid.New().String(),
|
||||
injectCtx,
|
||||
urls,
|
||||
stor,
|
||||
stor.JoinStoragePath(dirPath),
|
||||
ytdlp.NewProgress(msgID, userID),
|
||||
)
|
||||
|
||||
// Add task to queue
|
||||
if err := core.AddTask(injectCtx, task); err != nil {
|
||||
logger.Errorf("Failed to add yt-dlp task: %s", err)
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
|
||||
ID: msgID,
|
||||
Message: i18n.T(i18nk.BotMsgCommonInfoTaskAdded, nil),
|
||||
})
|
||||
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
63
client/bot/handlers/ytdlp.go
Normal file
63
client/bot/handlers/ytdlp.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/celestix/gotgproto/dispatcher"
|
||||
"github.com/celestix/gotgproto/ext"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tcbdata"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
logger := log.FromContext(ctx)
|
||||
args := strings.Split(update.EffectiveMessage.Text, " ")
|
||||
if len(args) < 2 {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpUsage)), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
urls := args[1:]
|
||||
// Validate and clean URLs
|
||||
for i, link := range urls {
|
||||
urls[i] = strings.TrimSpace(link)
|
||||
u, err := url.Parse(link)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
logger.Warnf("Invalid URL: %s", link)
|
||||
urls[i] = ""
|
||||
}
|
||||
}
|
||||
urls = slice.Compact(urls)
|
||||
|
||||
if len(urls) == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
|
||||
logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls))
|
||||
|
||||
// Build storage selection keyboard
|
||||
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
|
||||
TaskType: tasktype.TaskTypeYtdlp,
|
||||
YtdlpURLs: urls,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpInfoUrlsSelectStorage, map[string]any{
|
||||
"Count": len(urls),
|
||||
})), &ext.ReplyOpts{
|
||||
Markup: markup,
|
||||
})
|
||||
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/api"
|
||||
"github.com/krau/SaveAny-Bot/client/bot"
|
||||
userclient "github.com/krau/SaveAny-Bot/client/user"
|
||||
"github.com/krau/SaveAny-Bot/common/cache"
|
||||
@@ -76,7 +77,11 @@ func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) {
|
||||
logger.Fatal("User login failed", "error", err)
|
||||
}
|
||||
}
|
||||
return bot.Init(ctx), nil
|
||||
exitChan := bot.Init(ctx)
|
||||
if err := api.Init(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to init API server: %w", err)
|
||||
}
|
||||
return exitChan, nil
|
||||
}
|
||||
|
||||
func cleanCache() {
|
||||
|
||||
@@ -9,6 +9,7 @@ const (
|
||||
BotMsgAria2ErrorAria2NotEnabled Key = "bot.msg.aria2.error_aria2_not_enabled"
|
||||
BotMsgAria2InfoAddingAria2Download Key = "bot.msg.aria2.info_adding_aria2_download"
|
||||
BotMsgAria2InfoAria2DownloadAdded Key = "bot.msg.aria2.info_aria2_download_added"
|
||||
BotMsgAria2InfoSelectStorage Key = "bot.msg.aria2.info_select_storage"
|
||||
BotMsgCancelErrorCancelFailed Key = "bot.msg.cancel.error_cancel_failed"
|
||||
BotMsgCancelInfoCancelRequested Key = "bot.msg.cancel.info_cancel_requested"
|
||||
BotMsgCancelInfoCancellingTask Key = "bot.msg.cancel.info_cancelling_task"
|
||||
@@ -32,6 +33,7 @@ const (
|
||||
BotMsgCmdUnwatch Key = "bot.msg.cmd.unwatch"
|
||||
BotMsgCmdUpdate Key = "bot.msg.cmd.update"
|
||||
BotMsgCmdWatch Key = "bot.msg.cmd.watch"
|
||||
BotMsgCmdYtdlp Key = "bot.msg.cmd.ytdlp"
|
||||
BotMsgCommonCancelButtonText Key = "bot.msg.common.cancel_button_text"
|
||||
BotMsgCommonErrorBuildDirSelectKeyboardFailed Key = "bot.msg.common.error_build_dir_select_keyboard_failed"
|
||||
BotMsgCommonErrorBuildStorageSelectKeyboardFailed Key = "bot.msg.common.error_build_storage_select_keyboard_failed"
|
||||
@@ -127,15 +129,20 @@ const (
|
||||
BotMsgParserInfoInstallPluginSuccess Key = "bot.msg.parser.info_install_plugin_success"
|
||||
BotMsgParserPluginNotEnabled Key = "bot.msg.parser.plugin_not_enabled"
|
||||
BotMsgParserPromptReplyWithParserFile Key = "bot.msg.parser.prompt_reply_with_parser_file"
|
||||
BotMsgProgressAria2Done Key = "bot.msg.progress.aria2_done"
|
||||
BotMsgProgressAria2Downloading Key = "bot.msg.progress.aria2_downloading"
|
||||
BotMsgProgressAria2Start Key = "bot.msg.progress.aria2_start"
|
||||
BotMsgProgressAvgSpeedPrefix Key = "bot.msg.progress.avg_speed_prefix"
|
||||
BotMsgProgressBatchDonePrefix Key = "bot.msg.progress.batch_done_prefix"
|
||||
BotMsgProgressBatchProcessingPrefix Key = "bot.msg.progress.batch_processing_prefix"
|
||||
BotMsgProgressBatchStartPrefix Key = "bot.msg.progress.batch_start_prefix"
|
||||
BotMsgProgressCurrentProgressPrefix Key = "bot.msg.progress.current_progress_prefix"
|
||||
BotMsgProgressCurrentSpeedPrefix Key = "bot.msg.progress.current_speed_prefix"
|
||||
BotMsgProgressDirectDonePrefix Key = "bot.msg.progress.direct_done_prefix"
|
||||
BotMsgProgressDirectStart Key = "bot.msg.progress.direct_start"
|
||||
BotMsgProgressDownloadDonePrefix Key = "bot.msg.progress.download_done_prefix"
|
||||
BotMsgProgressDownloadFailedPrefix Key = "bot.msg.progress.download_failed_prefix"
|
||||
BotMsgProgressDownloadedPrefix Key = "bot.msg.progress.downloaded_prefix"
|
||||
BotMsgProgressDownloadingPrefix Key = "bot.msg.progress.downloading_prefix"
|
||||
BotMsgProgressErrorPrefix Key = "bot.msg.progress.error_prefix"
|
||||
BotMsgProgressFileNamePrefix Key = "bot.msg.progress.file_name_prefix"
|
||||
@@ -154,6 +161,9 @@ const (
|
||||
BotMsgProgressTelegraphProgressPrefix Key = "bot.msg.progress.telegraph_progress_prefix"
|
||||
BotMsgProgressTelegraphStartPrefix Key = "bot.msg.progress.telegraph_start_prefix"
|
||||
BotMsgProgressTotalSizePrefix Key = "bot.msg.progress.total_size_prefix"
|
||||
BotMsgProgressYtdlpDone Key = "bot.msg.progress.ytdlp_done"
|
||||
BotMsgProgressYtdlpDownloading Key = "bot.msg.progress.ytdlp_downloading"
|
||||
BotMsgProgressYtdlpStart Key = "bot.msg.progress.ytdlp_start"
|
||||
BotMsgRuleErrorCreateRuleFailed Key = "bot.msg.rule.error_create_rule_failed"
|
||||
BotMsgRuleErrorDeleteRuleFailed Key = "bot.msg.rule.error_delete_rule_failed"
|
||||
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
|
||||
@@ -229,6 +239,11 @@ const (
|
||||
BotMsgWatchInfoWatchListFilterPrefix Key = "bot.msg.watch.info_watch_list_filter_prefix"
|
||||
BotMsgWatchInfoWatchListHeader Key = "bot.msg.watch.info_watch_list_header"
|
||||
BotMsgWatchHelpText Key = "bot.msg.watch_help_text"
|
||||
BotMsgYtdlpErrorDownloadFailed Key = "bot.msg.ytdlp.error_download_failed"
|
||||
BotMsgYtdlpErrorNoValidUrls Key = "bot.msg.ytdlp.error_no_valid_urls"
|
||||
BotMsgYtdlpInfoDownloading Key = "bot.msg.ytdlp.info_downloading"
|
||||
BotMsgYtdlpInfoUrlsSelectStorage Key = "bot.msg.ytdlp.info_urls_select_storage"
|
||||
BotMsgYtdlpUsage Key = "bot.msg.ytdlp.usage"
|
||||
ConfigErrDuplicateStorageName Key = "config.err.duplicate_storage_name"
|
||||
ConfigErrInvalidCacheDir Key = "config.err.invalid_cache_dir"
|
||||
ErrCleanCacheFailed Key = "err.clean_cache_failed"
|
||||
|
||||
@@ -50,6 +50,8 @@ bot:
|
||||
rule: "Manage auto-save rules"
|
||||
save: "Save files"
|
||||
dl: "Download files from given links"
|
||||
aria2dl: "Download files using Aria2"
|
||||
ytdlp: "Download video/audio using yt-dlp"
|
||||
task: "Manage task queue"
|
||||
cancel: "Cancel task"
|
||||
watch: "Watch chats (UserBot)"
|
||||
@@ -286,6 +288,12 @@ bot:
|
||||
usage: "Usage: /dl <url1> <url2> ..."
|
||||
error_no_valid_links: "No valid links to download"
|
||||
info_files_select_storage: "Total {{.Count}} files, please select storage"
|
||||
ytdlp:
|
||||
usage: "Usage: /ytdlp <URL1> <URL2> ..."
|
||||
error_no_valid_urls: "No valid URLs"
|
||||
info_urls_select_storage: "Found {{.Count}} links, please select storage"
|
||||
info_downloading: "Downloading via yt-dlp..."
|
||||
error_download_failed: "yt-dlp download failed: {{.Error}}"
|
||||
cancel:
|
||||
usage: "Usage: /cancel <task_id>"
|
||||
error_cancel_failed: "Failed to cancel task: {{.Error}}"
|
||||
@@ -326,7 +334,22 @@ bot:
|
||||
direct_start: "Starting download, total size: {{.SizeMB}} MB ({{.Count}} files)"
|
||||
file_name_prefix: "Filename: "
|
||||
error_prefix: "\nError: "
|
||||
aria2_start: "Waiting for Aria2 to complete download (GID: {{.GID}})..."
|
||||
aria2_downloading: "Aria2 downloading (GID: {{.GID}})\n"
|
||||
aria2_done: "Aria2 download completed and transferred (GID: {{.GID}})\n"
|
||||
ytdlp_start: "Starting yt-dlp download ({{.Count}} links)..."
|
||||
ytdlp_downloading: "yt-dlp downloading ({{.Count}} links)\n"
|
||||
ytdlp_done: "yt-dlp download completed and transferred ({{.Count}} files)\n"
|
||||
downloaded_prefix: "\nDownloaded: "
|
||||
current_speed_prefix: "\nCurrent speed: "
|
||||
syncpeers:
|
||||
start: "Starting to sync peers..."
|
||||
done: "Peer sync completed, total {{.Count}} chats synced"
|
||||
failed: "Peer sync failed: {{.Error}}"
|
||||
aria2:
|
||||
error_aria2_not_enabled: "Aria2 feature is not enabled in the configuration"
|
||||
error_aria2_client_init_failed: "Aria2 client initialization failed: {{.Error}}"
|
||||
info_adding_aria2_download: "Adding Aria2 download task..."
|
||||
error_adding_aria2_download: "Failed to add Aria2 download task: {{.Error}}"
|
||||
info_aria2_download_added: "Aria2 download task added, GID: {{.GID}}"
|
||||
info_select_storage: "Please select storage, the task will be added to Aria2 download queue after selection"
|
||||
|
||||
@@ -52,6 +52,7 @@ bot:
|
||||
save: "保存文件"
|
||||
dl: "下载给定链接的文件"
|
||||
aria2dl: "使用 Aria2 下载给定链接的文件"
|
||||
ytdlp: "使用 yt-dlp 下载视频/音频"
|
||||
task: "管理任务队列"
|
||||
cancel: "取消任务"
|
||||
watch: "监听聊天(UserBot)"
|
||||
@@ -288,6 +289,12 @@ bot:
|
||||
usage: "用法: /dl <链接1> <链接2> ..."
|
||||
error_no_valid_links: "没有有效的链接可供下载"
|
||||
info_files_select_storage: "共 {{.Count}} 个文件, 请选择存储位置"
|
||||
ytdlp:
|
||||
usage: "用法: /ytdlp <URL1> <URL2> ..."
|
||||
error_no_valid_urls: "没有有效的 URL"
|
||||
info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置"
|
||||
info_downloading: "正在通过 yt-dlp 下载..."
|
||||
error_download_failed: "yt-dlp 下载失败: {{.Error}}"
|
||||
cancel:
|
||||
usage: "用法: /cancel <task_id>"
|
||||
error_cancel_failed: "取消任务失败: {{.Error}}"
|
||||
@@ -328,6 +335,14 @@ bot:
|
||||
direct_start: "开始下载, 总大小: {{.SizeMB}} MB ({{.Count}} 个文件)"
|
||||
file_name_prefix: "文件名: "
|
||||
error_prefix: "\n错误: "
|
||||
aria2_start: "等待 Aria2 下载完成 (GID: {{.GID}})..."
|
||||
aria2_downloading: "Aria2 正在下载 (GID: {{.GID}})\n"
|
||||
aria2_done: "Aria2 下载完成并已转存 (GID: {{.GID}})\n"
|
||||
ytdlp_start: "开始使用 yt-dlp 下载 ({{.Count}} 个链接)..."
|
||||
ytdlp_downloading: "yt-dlp 正在下载 ({{.Count}} 个链接)\n"
|
||||
ytdlp_done: "yt-dlp 下载完成并已转存 ({{.Count}} 个文件)\n"
|
||||
downloaded_prefix: "\n已下载: "
|
||||
current_speed_prefix: "\n当前速度: "
|
||||
syncpeers:
|
||||
start: "正在同步对话列表..."
|
||||
success: "对话列表同步完成, 共同步 {{.Count}} 个对话"
|
||||
@@ -338,3 +353,4 @@ bot:
|
||||
info_adding_aria2_download: "正在添加 Aria2 下载任务..."
|
||||
error_adding_aria2_download: "添加 Aria2 下载任务失败: {{.Error}}"
|
||||
info_aria2_download_added: "Aria2 下载任务已添加, GID: {{.GID}}"
|
||||
info_select_storage: "请选择存储位置, 选择后将添加到 Aria2 下载队列"
|
||||
@@ -18,6 +18,28 @@ token = ""
|
||||
enable = false
|
||||
url = "socks5://127.0.0.1:7890"
|
||||
|
||||
# Aria2 配置
|
||||
[aria2]
|
||||
# 启用 Aria2 下载支持
|
||||
enable = false
|
||||
# Aria2 RPC URL
|
||||
url = "http://localhost:6800/jsonrpc"
|
||||
# Aria2 RPC Secret (如果配置了 rpc-secret)
|
||||
secret = ""
|
||||
# 转存完成后删除 Aria2 下载的本地文件
|
||||
remove_after_transfer = true
|
||||
|
||||
# HTTP API 配置
|
||||
[api]
|
||||
# 启用 HTTP API 服务
|
||||
enable = false
|
||||
# API 服务监听端口
|
||||
port = 8080
|
||||
# API 访问令牌 (留空则不验证)
|
||||
token = ""
|
||||
# 任务完成回调 Webhook URL (留空则不回调)
|
||||
webhook_url = ""
|
||||
|
||||
# 存储列表
|
||||
[[storages]]
|
||||
# 标识名, 需要唯一
|
||||
|
||||
@@ -33,12 +33,21 @@ type Config struct {
|
||||
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
|
||||
Parser parserConfig `toml:"parser" mapstructure:"parser" json:"parser"`
|
||||
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
|
||||
API apiConfig `toml:"api" mapstructure:"api" json:"api"`
|
||||
}
|
||||
|
||||
type aria2Config struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Url string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Url string `toml:"url" mapstructure:"url" json:"url"`
|
||||
Secret string `toml:"secret" mapstructure:"secret" json:"secret"`
|
||||
KeepFile bool `toml:"keep_file" mapstructure:"keep_file" json:"keep_file"`
|
||||
}
|
||||
|
||||
type apiConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Port int `toml:"port" mapstructure:"port" json:"port"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
WebhookURL string `toml:"webhook_url" mapstructure:"webhook_url" json:"webhook_url"`
|
||||
}
|
||||
|
||||
var cfg = &Config{}
|
||||
@@ -114,6 +123,12 @@ func Init(ctx context.Context, configFile ...string) error {
|
||||
// 数据库
|
||||
"db.path": "data/saveany.db",
|
||||
"db.session": "data/session.db",
|
||||
|
||||
// API
|
||||
"api.enable": false,
|
||||
"api.port": 8080,
|
||||
"api.token": "",
|
||||
"api.webhook_url": "",
|
||||
}
|
||||
|
||||
for key, value := range defaultConfigs {
|
||||
|
||||
250
core/tasks/aria2dl/execute.go
Normal file
250
core/tasks/aria2dl/execute.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting aria2 download task %s (GID: %s)", t.ID, t.GID)
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnStart(ctx, t)
|
||||
}
|
||||
|
||||
// Wait for aria2 download to complete
|
||||
if err := t.waitForDownload(ctx); err != nil {
|
||||
// If context was canceled, also cancel the aria2 download
|
||||
if errors.Is(err, context.Canceled) {
|
||||
t.cancelAria2Download()
|
||||
}
|
||||
logger.Errorf("Aria2 download failed: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Transfer downloaded files to storage
|
||||
if err := t.transferFiles(ctx); err != nil {
|
||||
logger.Errorf("File transfer failed: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("Aria2 task %s completed successfully", t.ID)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, nil)
|
||||
}
|
||||
|
||||
// Clean up aria2 download result
|
||||
if _, err := t.Aria2Client.RemoveDownloadResult(context.Background(), t.GID); err != nil {
|
||||
logger.Warnf("Failed to remove aria2 download result: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForDownload waits for aria2 to complete the download
|
||||
func (t *Task) waitForDownload(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
status, err := t.getStatus(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, status)
|
||||
}
|
||||
|
||||
// Check if download is complete
|
||||
if status.IsDownloadComplete() {
|
||||
// Handle metadata downloads (torrent/magnet) that spawn follow-up downloads
|
||||
if len(status.FollowedBy) > 0 {
|
||||
logger.Infof("Switching from metadata GID %s to actual download GID: %s", t.GID, status.FollowedBy[0])
|
||||
t.GID = status.FollowedBy[0]
|
||||
continue
|
||||
}
|
||||
logger.Infof("Download completed for GID %s", t.GID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
if status.IsDownloadError() {
|
||||
return fmt.Errorf("aria2 download error: %s (code: %s)", status.ErrorMessage, status.ErrorCode)
|
||||
}
|
||||
|
||||
if status.IsDownloadRemoved() {
|
||||
return errors.New("aria2 download was removed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getStatus retrieves the current status of the download
|
||||
func (t *Task) getStatus(ctx context.Context) (*aria2.Status, error) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
// Try active/waiting queue first
|
||||
status, err := t.Aria2Client.TellStatus(ctx, t.GID)
|
||||
if err == nil {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// Check stopped queue
|
||||
logger.Debugf("Task not in active queue, checking stopped queue")
|
||||
stoppedTasks, stopErr := t.Aria2Client.TellStopped(ctx, -1, 100)
|
||||
if stopErr != nil {
|
||||
return nil, fmt.Errorf("failed to get aria2 status: %w", err)
|
||||
}
|
||||
|
||||
for _, task := range stoppedTasks {
|
||||
if task.GID == t.GID {
|
||||
logger.Debugf("Found task in stopped queue with status: %s", task.Status)
|
||||
return &task, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("task GID %s not found: %w", t.GID, err)
|
||||
}
|
||||
|
||||
// transferFiles transfers downloaded files from aria2 to storage
|
||||
func (t *Task) transferFiles(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
status, err := t.Aria2Client.TellStatus(ctx, t.GID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get final status: %w", err)
|
||||
}
|
||||
|
||||
if len(status.Files) == 0 {
|
||||
return errors.New("no files in aria2 download")
|
||||
}
|
||||
|
||||
logger.Infof("Transferring %d file(s) to storage %s", len(status.Files), t.Storage.Name())
|
||||
transferredCount := 0
|
||||
|
||||
for _, file := range status.Files {
|
||||
if file.Selected != "true" {
|
||||
logger.Debugf("Skipping unselected file: %s", file.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
fileName := filepath.Base(file.Path)
|
||||
|
||||
// Skip torrent metadata files
|
||||
if filepath.Ext(fileName) == ".torrent" {
|
||||
logger.Debugf("Skipping torrent metadata file: %s", fileName)
|
||||
t.removeFileIfNeeded(file.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := t.transferFile(ctx, file.Path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transferredCount++
|
||||
t.removeFileIfNeeded(file.Path)
|
||||
}
|
||||
|
||||
if transferredCount == 0 {
|
||||
return errors.New("no files were transferred")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// transferFile transfers a single file to storage
|
||||
func (t *Task) transferFile(ctx context.Context, filePath string) error {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
// Check if file exists
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
logger.Warnf("Downloaded file not found: %s", filePath)
|
||||
return nil // Not a fatal error, continue with other files
|
||||
}
|
||||
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Open file
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file %s: %w", filePath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Set content length in context for storage
|
||||
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
|
||||
|
||||
// Save to storage
|
||||
fileName := filepath.Base(filePath)
|
||||
destPath := filepath.Join(t.StorPath, fileName)
|
||||
|
||||
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
|
||||
|
||||
if err := t.Storage.Save(ctx, f, destPath); err != nil {
|
||||
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
|
||||
}
|
||||
|
||||
logger.Infof("Successfully transferred file %s", fileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeFileIfNeeded removes a file if RemoveAfterTransfer is enabled
|
||||
func (t *Task) removeFileIfNeeded(filePath string) {
|
||||
if config.C().Aria2.KeepFile {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.FromContext(t.ctx)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
logger.Warnf("Failed to remove local file %s: %v", filePath, err)
|
||||
} else {
|
||||
logger.Debugf("Removed local file %s", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// cancelAria2Download cancels the aria2 download task
|
||||
func (t *Task) cancelAria2Download() {
|
||||
logger := log.FromContext(t.ctx)
|
||||
logger.Infof("Canceling aria2 download GID: %s", t.GID)
|
||||
|
||||
// Use a background context with timeout for cleanup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to force remove the download
|
||||
if _, err := t.Aria2Client.ForceRemove(ctx, t.GID); err != nil {
|
||||
logger.Warnf("Failed to cancel aria2 download %s: %v", t.GID, err)
|
||||
} else {
|
||||
logger.Infof("Successfully canceled aria2 download %s", t.GID)
|
||||
}
|
||||
|
||||
// Also remove the download result to clean up
|
||||
if _, err := t.Aria2Client.RemoveDownloadResult(ctx, t.GID); err != nil {
|
||||
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
|
||||
}
|
||||
}
|
||||
189
core/tasks/aria2dl/progress.go
Normal file
189
core/tasks/aria2dl/progress.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/dlutil"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
)
|
||||
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, task *Task)
|
||||
OnProgress(ctx context.Context, task *Task, status *aria2.Status)
|
||||
OnDone(ctx context.Context, task *Task, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
msgID int
|
||||
chatID int64
|
||||
start time.Time
|
||||
lastUpdatePercent atomic.Int32
|
||||
}
|
||||
|
||||
// OnStart implements ProgressTracker.
|
||||
func (p *Progress) OnStart(ctx context.Context, task *Task) {
|
||||
logger := log.FromContext(ctx)
|
||||
p.start = time.Now()
|
||||
p.lastUpdatePercent.Store(0)
|
||||
logger.Infof("Aria2 task started: message_id=%d, chat_id=%d, gid=%s", p.msgID, p.chatID, task.GID)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext == nil {
|
||||
return
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Start, map[string]any{
|
||||
"GID": task.GID,
|
||||
}))); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
|
||||
// OnProgress implements ProgressTracker.
|
||||
func (p *Progress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
|
||||
completedLength, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
|
||||
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
|
||||
|
||||
if totalLength == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
percent := int((completedLength * 100) / totalLength)
|
||||
if p.lastUpdatePercent.Load() == int32(percent) {
|
||||
return
|
||||
}
|
||||
p.lastUpdatePercent.Store(int32(percent))
|
||||
|
||||
log.FromContext(ctx).Debugf("Aria2 progress update: %s, %d/%d", task.GID, completedLength, totalLength)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Downloading, map[string]any{
|
||||
"GID": task.GID,
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressDownloadedPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("%.2f MB / %.2f MB", float64(completedLength)/(1024*1024), float64(totalLength)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentSpeedPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", float64(downloadSpeed)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAvgSpeedPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f MB/s", dlutil.GetSpeed(completedLength, p.start)/(1024*1024))),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressCurrentProgressPrefix, nil)),
|
||||
styling.Bold(fmt.Sprintf("%.2f%%", float64(percent))),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDone implements ProgressTracker.
|
||||
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
|
||||
logger := log.FromContext(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Aria2 task %s was canceled", task.TaskID())
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
|
||||
"TaskID": task.TaskID(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("Aria2 task %s failed: %s", task.TaskID(), err)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Aria2 task %s completed successfully", task.TaskID())
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressAria2Done, map[string]any{
|
||||
"GID": task.GID,
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||
); err != nil {
|
||||
logger.Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
var _ ProgressTracker = (*Progress)(nil)
|
||||
|
||||
func NewProgress(msgID int, userID int64) ProgressTracker {
|
||||
return &Progress{
|
||||
msgID: msgID,
|
||||
chatID: userID,
|
||||
}
|
||||
}
|
||||
61
core/tasks/aria2dl/task.go
Normal file
61
core/tasks/aria2dl/task.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
var _ core.Executable = (*Task)(nil)
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
ctx context.Context
|
||||
GID string
|
||||
URIs []string
|
||||
Aria2Client *aria2.Client
|
||||
Storage storage.Storage
|
||||
StorPath string
|
||||
Progress ProgressTracker
|
||||
}
|
||||
|
||||
// Title implements core.Executable.
|
||||
func (t *Task) Title() string {
|
||||
return fmt.Sprintf("[%s](Aria2 GID:%s->%s:%s)", t.Type(), t.GID, t.Storage.Name(), t.StorPath)
|
||||
}
|
||||
|
||||
// Type implements core.Executable.
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeAria2
|
||||
}
|
||||
|
||||
// TaskID implements core.Executable.
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func NewTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
gid string,
|
||||
uris []string,
|
||||
aria2Client *aria2.Client,
|
||||
stor storage.Storage,
|
||||
storPath string,
|
||||
progressTracker ProgressTracker,
|
||||
) *Task {
|
||||
return &Task{
|
||||
ID: id,
|
||||
ctx: ctx,
|
||||
GID: gid,
|
||||
URIs: uris,
|
||||
Aria2Client: aria2Client,
|
||||
Storage: stor,
|
||||
StorPath: storPath,
|
||||
Progress: progressTracker,
|
||||
}
|
||||
}
|
||||
209
core/tasks/aria2dl/task_test.go
Normal file
209
core/tasks/aria2dl/task_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package aria2dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
storconfig "github.com/krau/SaveAny-Bot/config/storage"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
)
|
||||
|
||||
type mockStorage struct {
|
||||
name string
|
||||
savePath string
|
||||
}
|
||||
|
||||
func (m *mockStorage) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockStorage) Type() storenum.StorageType {
|
||||
return storenum.StorageType("mock")
|
||||
}
|
||||
|
||||
func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
|
||||
m.savePath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStorage) Exists(ctx context.Context, path string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockStorage) JoinStoragePath(path string) string {
|
||||
return path
|
||||
}
|
||||
|
||||
type mockProgress struct {
|
||||
started bool
|
||||
done bool
|
||||
doneErr error
|
||||
progress int
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnStart(ctx context.Context, task *Task) {
|
||||
m.started = true
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnProgress(ctx context.Context, task *Task, status *aria2.Status) {
|
||||
m.progress++
|
||||
}
|
||||
|
||||
func (m *mockProgress) OnDone(ctx context.Context, task *Task, err error) {
|
||||
m.done = true
|
||||
m.doneErr = err
|
||||
}
|
||||
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
if task.ID != "test-task-id" {
|
||||
t.Errorf("Expected task ID to be 'test-task-id', got '%s'", task.ID)
|
||||
}
|
||||
|
||||
if task.GID != "test-gid" {
|
||||
t.Errorf("Expected GID to be 'test-gid', got '%s'", task.GID)
|
||||
}
|
||||
|
||||
if task.Type() != tasktype.TaskTypeAria2 {
|
||||
t.Errorf("Expected task type to be TaskTypeAria2, got '%s'", task.Type())
|
||||
}
|
||||
|
||||
if task.TaskID() != "test-task-id" {
|
||||
t.Errorf("Expected TaskID() to return 'test-task-id', got '%s'", task.TaskID())
|
||||
}
|
||||
|
||||
if task.Storage.Name() != "test-storage" {
|
||||
t.Errorf("Expected storage name to be 'test-storage', got '%s'", task.Storage.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTracker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
// Test OnStart
|
||||
mockProg.OnStart(ctx, task)
|
||||
if !mockProg.started {
|
||||
t.Error("Expected OnStart to set started to true")
|
||||
}
|
||||
|
||||
// Test OnProgress
|
||||
status := &aria2.Status{
|
||||
GID: "test-gid",
|
||||
Status: "active",
|
||||
TotalLength: "1000000",
|
||||
CompletedLength: "500000",
|
||||
DownloadSpeed: "100000",
|
||||
}
|
||||
mockProg.OnProgress(ctx, task, status)
|
||||
if mockProg.progress != 1 {
|
||||
t.Errorf("Expected progress to be 1, got %d", mockProg.progress)
|
||||
}
|
||||
|
||||
// Test OnDone
|
||||
mockProg.OnDone(ctx, task, nil)
|
||||
if !mockProg.done {
|
||||
t.Error("Expected OnDone to set done to true")
|
||||
}
|
||||
if mockProg.doneErr != nil {
|
||||
t.Errorf("Expected doneErr to be nil, got %v", mockProg.doneErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskTitle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid-123",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil,
|
||||
mockStor,
|
||||
"/test/path",
|
||||
nil,
|
||||
)
|
||||
|
||||
title := task.Title()
|
||||
expectedSubstr := "test-gid-123"
|
||||
if len(title) == 0 {
|
||||
t.Error("Expected title to not be empty")
|
||||
}
|
||||
|
||||
// Check if title contains the GID
|
||||
found := false
|
||||
for i := 0; i < len(title)-len(expectedSubstr)+1; i++ {
|
||||
if title[i:i+len(expectedSubstr)] == expectedSubstr {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected title to contain GID '%s', got '%s'", expectedSubstr, title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStor := &mockStorage{name: "test-storage"}
|
||||
mockProg := &mockProgress{}
|
||||
|
||||
task := NewTask(
|
||||
"test-task-id",
|
||||
ctx,
|
||||
"test-gid",
|
||||
[]string{"http://example.com/file.zip"},
|
||||
nil, // nil client will cause Execute to fail/timeout
|
||||
mockStor,
|
||||
"/test/path",
|
||||
mockProg,
|
||||
)
|
||||
|
||||
// Just verify the task structure is valid
|
||||
if task.ctx.Err() != nil {
|
||||
t.Error("Context should not be cancelled yet")
|
||||
}
|
||||
|
||||
// Wait for context to timeout
|
||||
<-ctx.Done()
|
||||
if ctx.Err() == nil {
|
||||
t.Error("Context should be cancelled after timeout")
|
||||
}
|
||||
}
|
||||
182
core/tasks/ytdlp/execute.go
Normal file
182
core/tasks/ytdlp/execute.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package ytdlp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
ytdlp "github.com/lrstanley/go-ytdlp"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting yt-dlp download task %s", t.ID)
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnStart(ctx, t)
|
||||
}
|
||||
|
||||
// Create temporary directory for downloads
|
||||
tempDir, err := os.MkdirTemp(config.C().Temp.BasePath, "ytdlp-*")
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create temp directory: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir) // Clean up temp directory
|
||||
|
||||
logger.Debugf("Created temp directory: %s", tempDir)
|
||||
|
||||
// Download files using yt-dlp
|
||||
downloadedFiles, err := t.downloadFiles(ctx, tempDir)
|
||||
if err != nil {
|
||||
logger.Errorf("yt-dlp download failed: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if len(downloadedFiles) == 0 {
|
||||
err := errors.New("no files were downloaded")
|
||||
logger.Error(err.Error())
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Transfer downloaded files to storage
|
||||
logger.Infof("Transferring %d file(s) to storage %s", len(downloadedFiles), t.Storage.Name())
|
||||
for _, filePath := range downloadedFiles {
|
||||
if err := t.transferFile(ctx, filePath); err != nil {
|
||||
logger.Errorf("File transfer failed: %v", err)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("yt-dlp task %s completed successfully", t.ID)
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnDone(ctx, t, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadFiles downloads files using yt-dlp and returns the list of downloaded file paths
|
||||
func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
// Configure yt-dlp command
|
||||
cmd := ytdlp.New().
|
||||
FormatSort("res,ext:mp4:m4a").
|
||||
RecodeVideo("mp4").
|
||||
Output(filepath.Join(tempDir, "%(title)s.%(ext)s")).
|
||||
RestrictFilenames()
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, "Downloading...")
|
||||
}
|
||||
|
||||
// Execute download with URLs as arguments
|
||||
logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs))
|
||||
|
||||
// Run with context for cancellation support
|
||||
result, err := cmd.Run(ctx, t.URLs...)
|
||||
if err != nil {
|
||||
// Check if context was canceled
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("yt-dlp execution failed: %w", err)
|
||||
}
|
||||
|
||||
if result.ExitCode != 0 {
|
||||
return nil, fmt.Errorf("yt-dlp exited with code %d: %s", result.ExitCode, result.Stderr)
|
||||
}
|
||||
|
||||
// List downloaded files
|
||||
files, err := os.ReadDir(tempDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read temp directory: %w", err)
|
||||
}
|
||||
|
||||
var downloadedFiles []string
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
fullPath := filepath.Join(tempDir, file.Name())
|
||||
downloadedFiles = append(downloadedFiles, fullPath)
|
||||
logger.Debugf("Downloaded file: %s", file.Name())
|
||||
}
|
||||
|
||||
return downloadedFiles, nil
|
||||
}
|
||||
|
||||
// transferFile transfers a single file to storage
|
||||
func (t *Task) transferFile(ctx context.Context, filePath string) error {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
// Check if file exists
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
logger.Warnf("Downloaded file not found: %s", filePath)
|
||||
return nil // Not a fatal error
|
||||
}
|
||||
return fmt.Errorf("failed to stat file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Open file
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file %s: %w", filePath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Set content length in context for storage
|
||||
ctx = context.WithValue(ctx, ctxkey.ContentLength, fileInfo.Size())
|
||||
|
||||
// Save to storage
|
||||
fileName := filepath.Base(filePath)
|
||||
// Remove special characters from filename if needed
|
||||
fileName = sanitizeFilename(fileName)
|
||||
destPath := filepath.Join(t.StorPath, fileName)
|
||||
|
||||
logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)
|
||||
|
||||
if err := t.Storage.Save(ctx, f, destPath); err != nil {
|
||||
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
|
||||
}
|
||||
|
||||
logger.Infof("Successfully transferred file %s", fileName)
|
||||
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, fmt.Sprintf("Transferred: %s", fileName))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFilename removes or replaces problematic characters in filenames
|
||||
func sanitizeFilename(name string) string {
|
||||
// yt-dlp with --restrict-filenames should already handle most cases
|
||||
// but we can do additional sanitization if needed
|
||||
name = strings.ReplaceAll(name, ":", "_")
|
||||
name = strings.ReplaceAll(name, "\"", "'")
|
||||
return name
|
||||
}
|
||||
183
core/tasks/ytdlp/progress.go
Normal file
183
core/tasks/ytdlp/progress.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package ytdlp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/gotd/td/telegram/message/entity"
|
||||
"github.com/gotd/td/telegram/message/styling"
|
||||
"github.com/gotd/td/tg"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
|
||||
)
|
||||
|
||||
// ProgressTracker defines the interface for tracking ytdlp task progress
|
||||
type ProgressTracker interface {
|
||||
OnStart(ctx context.Context, task *Task)
|
||||
OnProgress(ctx context.Context, task *Task, status string)
|
||||
OnDone(ctx context.Context, task *Task, err error)
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
msgID int
|
||||
chatID int64
|
||||
start time.Time
|
||||
lastUpdate atomic.Value // stores time.Time
|
||||
minUpdateInterval time.Duration
|
||||
}
|
||||
|
||||
// OnStart implements ProgressTracker.
|
||||
func (p *Progress) OnStart(ctx context.Context, task *Task) {
|
||||
logger := log.FromContext(ctx)
|
||||
p.start = time.Now()
|
||||
p.lastUpdate.Store(time.Now())
|
||||
p.minUpdateInterval = 2 * time.Second // Avoid too frequent updates
|
||||
logger.Infof("yt-dlp task started: message_id=%d, chat_id=%d, urls=%d", p.msgID, p.chatID, len(task.URLs))
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext == nil {
|
||||
return
|
||||
}
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpStart, map[string]any{
|
||||
"Count": len(task.URLs),
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
|
||||
// OnProgress implements ProgressTracker.
|
||||
func (p *Progress) OnProgress(ctx context.Context, task *Task, status string) {
|
||||
// Throttle updates to avoid flooding Telegram API
|
||||
lastUpdateTime := p.lastUpdate.Load().(time.Time)
|
||||
if time.Since(lastUpdateTime) < p.minUpdateInterval {
|
||||
return
|
||||
}
|
||||
p.lastUpdate.Store(time.Now())
|
||||
|
||||
log.FromContext(ctx).Debugf("yt-dlp progress update: %s", status)
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDownloading, map[string]any{
|
||||
"Count": len(task.URLs),
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||
styling.Plain("\n\n"),
|
||||
styling.Plain(status),
|
||||
); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
req.SetReplyMarkup(&tg.ReplyInlineMarkup{
|
||||
Rows: []tg.KeyboardButtonRow{
|
||||
{
|
||||
Buttons: []tg.KeyboardButtonClass{
|
||||
tgutil.BuildCancelButton(task.TaskID()),
|
||||
},
|
||||
},
|
||||
}},
|
||||
)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDone implements ProgressTracker.
|
||||
func (p *Progress) OnDone(ctx context.Context, task *Task, err error) {
|
||||
logger := log.FromContext(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("yt-dlp task %s was canceled", task.TaskID())
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskCanceledWithId, map[string]any{
|
||||
"TaskID": task.TaskID(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("yt-dlp task %s failed: %s", task.TaskID(), err)
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
Message: i18n.T(i18nk.BotMsgProgressTaskFailedWithError, map[string]any{
|
||||
"Error": err.Error(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("yt-dlp task %s completed successfully", task.TaskID())
|
||||
|
||||
entityBuilder := entity.Builder{}
|
||||
if err := styling.Perform(&entityBuilder,
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressYtdlpDone, map[string]any{
|
||||
"Count": len(task.URLs),
|
||||
})),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgProgressSavePathPrefix, nil)),
|
||||
styling.Code(fmt.Sprintf("[%s]:%s", task.Storage.Name(), task.StorPath)),
|
||||
); err != nil {
|
||||
logger.Errorf("Failed to build entities: %s", err)
|
||||
return
|
||||
}
|
||||
text, entities := entityBuilder.Complete()
|
||||
req := &tg.MessagesEditMessageRequest{
|
||||
ID: p.msgID,
|
||||
}
|
||||
req.SetMessage(text)
|
||||
req.SetEntities(entities)
|
||||
|
||||
ext := tgutil.ExtFromContext(ctx)
|
||||
if ext != nil {
|
||||
ext.EditMessage(p.chatID, req)
|
||||
}
|
||||
}
|
||||
|
||||
var _ ProgressTracker = (*Progress)(nil)
|
||||
|
||||
func NewProgress(msgID int, userID int64) ProgressTracker {
|
||||
return &Progress{
|
||||
msgID: msgID,
|
||||
chatID: userID,
|
||||
minUpdateInterval: 2 * time.Second,
|
||||
}
|
||||
}
|
||||
58
core/tasks/ytdlp/task.go
Normal file
58
core/tasks/ytdlp/task.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package ytdlp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
var _ core.Executable = (*Task)(nil)
|
||||
|
||||
type Task struct {
|
||||
ID string
|
||||
ctx context.Context
|
||||
URLs []string
|
||||
Storage storage.Storage
|
||||
StorPath string
|
||||
Progress ProgressTracker
|
||||
}
|
||||
|
||||
// Title implements core.Executable.
|
||||
func (t *Task) Title() string {
|
||||
urlCount := len(t.URLs)
|
||||
if urlCount == 1 {
|
||||
return fmt.Sprintf("[%s](%s->%s:%s)", t.Type(), t.URLs[0], t.Storage.Name(), t.StorPath)
|
||||
}
|
||||
return fmt.Sprintf("[%s](%d URLs->%s:%s)", t.Type(), urlCount, t.Storage.Name(), t.StorPath)
|
||||
}
|
||||
|
||||
// Type implements core.Executable.
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeYtdlp
|
||||
}
|
||||
|
||||
// TaskID implements core.Executable.
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func NewTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
urls []string,
|
||||
stor storage.Storage,
|
||||
storPath string,
|
||||
progressTracker ProgressTracker,
|
||||
) *Task {
|
||||
return &Task{
|
||||
ID: id,
|
||||
ctx: ctx,
|
||||
URLs: urls,
|
||||
Storage: stor,
|
||||
StorPath: storPath,
|
||||
Progress: progressTracker,
|
||||
}
|
||||
}
|
||||
257
docs/content/en/usage/api.md
Normal file
257
docs/content/en/usage/api.md
Normal file
@@ -0,0 +1,257 @@
|
||||
---
|
||||
title: "HTTP API"
|
||||
weight: 4
|
||||
---
|
||||
|
||||
# HTTP API
|
||||
|
||||
SaveAny-Bot provides a RESTful HTTP API for programmatic file downloads from Telegram.
|
||||
|
||||
## Configuration
|
||||
|
||||
Enable the API in your `config.toml`:
|
||||
|
||||
```toml
|
||||
[api]
|
||||
# Enable HTTP API service
|
||||
enable = true
|
||||
# API server listen port
|
||||
port = 8080
|
||||
# API access token (leave empty to disable authentication)
|
||||
token = "your-secret-token-here"
|
||||
# Task completion callback webhook URL (leave empty to disable)
|
||||
webhook_url = "https://your-server.com/webhook"
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
If `token` is configured, all API requests (except `/health`) must include an `Authorization` header:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### Health Check
|
||||
|
||||
Check if the API server is running.
|
||||
|
||||
**Request:**
|
||||
```
|
||||
GET /health
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
### Create Download Task
|
||||
|
||||
Create a new file download task from a Telegram message link.
|
||||
|
||||
**Request:**
|
||||
```
|
||||
POST /api/v1/tasks
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer your-secret-token-here
|
||||
|
||||
{
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}
|
||||
```
|
||||
|
||||
**Request Parameters:**
|
||||
- `telegram_url` (required): Telegram message link (e.g., `https://t.me/channel/123`)
|
||||
- `user_id` (required): Telegram user ID (must be configured in `config.toml`)
|
||||
- `storage_name` (optional): Storage name to use. If not specified, uses the first available storage for the user
|
||||
- `dir_path` (optional): Directory path in storage. Default is `/`
|
||||
|
||||
**Response (201 Created):**
|
||||
```json
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"message": "task created successfully"
|
||||
}
|
||||
```
|
||||
|
||||
**Error Response (4xx/5xx):**
|
||||
```json
|
||||
{
|
||||
"error": "error description"
|
||||
}
|
||||
```
|
||||
|
||||
### Get Task Status
|
||||
|
||||
Get the status of a specific task.
|
||||
|
||||
**Request:**
|
||||
```
|
||||
GET /api/v1/tasks/{task_id}
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**Response (200 OK):**
|
||||
```json
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"status": "completed",
|
||||
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
|
||||
"created_at": "2024-01-19T04:30:00Z",
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
**Status Values:**
|
||||
- `queued`: Task is waiting in queue
|
||||
- `running`: Task is currently downloading
|
||||
- `completed`: Task completed successfully
|
||||
- `failed`: Task failed with error (see `error` field)
|
||||
- `canceled`: Task was canceled
|
||||
|
||||
### List All Tasks
|
||||
|
||||
List all queued and running tasks.
|
||||
|
||||
**Request:**
|
||||
```
|
||||
GET /api/v1/tasks
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**Response (200 OK):**
|
||||
```json
|
||||
{
|
||||
"queued": [
|
||||
{
|
||||
"id": "c9h8t1234abcd",
|
||||
"title": "[tgfiles](file1.pdf->local1:/downloads/file1.pdf)"
|
||||
}
|
||||
],
|
||||
"running": [
|
||||
{
|
||||
"id": "d2k9u5678efgh",
|
||||
"title": "[tgfiles](file2.pdf->local1:/downloads/file2.pdf)"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Cancel Task
|
||||
|
||||
Cancel a running or queued task.
|
||||
|
||||
**Request:**
|
||||
```
|
||||
DELETE /api/v1/tasks/{task_id}
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**Response (200 OK):**
|
||||
```json
|
||||
{
|
||||
"message": "task canceled"
|
||||
}
|
||||
```
|
||||
|
||||
## Webhook Callback
|
||||
|
||||
If `webhook_url` is configured, the API will send a POST request to the webhook URL when a task completes or fails.
|
||||
|
||||
**Webhook Request:**
|
||||
```
|
||||
POST {webhook_url}
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer your-secret-token-here
|
||||
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"status": "completed",
|
||||
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
|
||||
"created_at": "2024-01-19T04:30:00Z",
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Using cURL
|
||||
|
||||
**Create a download task:**
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/v1/tasks \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-secret-token-here" \
|
||||
-d '{
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}'
|
||||
```
|
||||
|
||||
**Get task status:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
**List all tasks:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/v1/tasks \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
**Cancel a task:**
|
||||
```bash
|
||||
curl -X DELETE http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
### Using Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_URL = "http://localhost:8080"
|
||||
TOKEN = "your-secret-token-here"
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {TOKEN}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Create a download task
|
||||
response = requests.post(
|
||||
f"{API_URL}/api/v1/tasks",
|
||||
headers=HEADERS,
|
||||
json={
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}
|
||||
)
|
||||
task_id = response.json()["task_id"]
|
||||
|
||||
# Get task status
|
||||
response = requests.get(
|
||||
f"{API_URL}/api/v1/tasks/{task_id}",
|
||||
headers=HEADERS
|
||||
)
|
||||
status = response.json()
|
||||
print(f"Task status: {status['status']}")
|
||||
```
|
||||
|
||||
## Security Recommendations
|
||||
|
||||
1. **Always use a strong token** for production environments
|
||||
2. **Use HTTPS** in production by placing the API behind a reverse proxy (e.g., Nginx, Caddy)
|
||||
3. **Keep logs secure** as they may contain sensitive information
|
||||
4. **Validate user permissions** - ensure `user_id` in requests corresponds to authorized users in your config
|
||||
257
docs/content/zh/usage/api.md
Normal file
257
docs/content/zh/usage/api.md
Normal file
@@ -0,0 +1,257 @@
|
||||
---
|
||||
title: "HTTP API"
|
||||
weight: 4
|
||||
---
|
||||
|
||||
# HTTP API
|
||||
|
||||
SaveAny-Bot 提供 RESTful HTTP API,支持通过编程方式从 Telegram 下载文件。
|
||||
|
||||
## 配置
|
||||
|
||||
在 `config.toml` 中启用 API:
|
||||
|
||||
```toml
|
||||
[api]
|
||||
# 启用 HTTP API 服务
|
||||
enable = true
|
||||
# API 服务监听端口
|
||||
port = 8080
|
||||
# API 访问令牌 (留空则不验证)
|
||||
token = "your-secret-token-here"
|
||||
# 任务完成回调 Webhook URL (留空则不回调)
|
||||
webhook_url = "https://your-server.com/webhook"
|
||||
```
|
||||
|
||||
## 认证
|
||||
|
||||
如果配置了 `token`,所有 API 请求(除了 `/health`)都必须包含 `Authorization` 头:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
## 端点
|
||||
|
||||
### 健康检查
|
||||
|
||||
检查 API 服务器是否正在运行。
|
||||
|
||||
**请求:**
|
||||
```
|
||||
GET /health
|
||||
```
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"status": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
### 创建下载任务
|
||||
|
||||
从 Telegram 消息链接创建新的文件下载任务。
|
||||
|
||||
**请求:**
|
||||
```
|
||||
POST /api/v1/tasks
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer your-secret-token-here
|
||||
|
||||
{
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}
|
||||
```
|
||||
|
||||
**请求参数:**
|
||||
- `telegram_url` (必填): Telegram 消息链接 (例如: `https://t.me/channel/123`)
|
||||
- `user_id` (必填): Telegram 用户 ID (必须在 `config.toml` 中配置)
|
||||
- `storage_name` (可选): 要使用的存储名称。如果未指定,使用用户的第一个可用存储
|
||||
- `dir_path` (可选): 存储中的目录路径。默认为 `/`
|
||||
|
||||
**响应 (201 Created):**
|
||||
```json
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"message": "task created successfully"
|
||||
}
|
||||
```
|
||||
|
||||
**错误响应 (4xx/5xx):**
|
||||
```json
|
||||
{
|
||||
"error": "错误描述"
|
||||
}
|
||||
```
|
||||
|
||||
### 获取任务状态
|
||||
|
||||
获取特定任务的状态。
|
||||
|
||||
**请求:**
|
||||
```
|
||||
GET /api/v1/tasks/{task_id}
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**响应 (200 OK):**
|
||||
```json
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"status": "completed",
|
||||
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
|
||||
"created_at": "2024-01-19T04:30:00Z",
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
**状态值:**
|
||||
- `queued`: 任务正在队列中等待
|
||||
- `running`: 任务正在下载
|
||||
- `completed`: 任务成功完成
|
||||
- `failed`: 任务失败(查看 `error` 字段)
|
||||
- `canceled`: 任务已取消
|
||||
|
||||
### 列出所有任务
|
||||
|
||||
列出所有排队和正在运行的任务。
|
||||
|
||||
**请求:**
|
||||
```
|
||||
GET /api/v1/tasks
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**响应 (200 OK):**
|
||||
```json
|
||||
{
|
||||
"queued": [
|
||||
{
|
||||
"id": "c9h8t1234abcd",
|
||||
"title": "[tgfiles](file1.pdf->local1:/downloads/file1.pdf)"
|
||||
}
|
||||
],
|
||||
"running": [
|
||||
{
|
||||
"id": "d2k9u5678efgh",
|
||||
"title": "[tgfiles](file2.pdf->local1:/downloads/file2.pdf)"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 取消任务
|
||||
|
||||
取消正在运行或排队的任务。
|
||||
|
||||
**请求:**
|
||||
```
|
||||
DELETE /api/v1/tasks/{task_id}
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
**响应 (200 OK):**
|
||||
```json
|
||||
{
|
||||
"message": "task canceled"
|
||||
}
|
||||
```
|
||||
|
||||
## Webhook 回调
|
||||
|
||||
如果配置了 `webhook_url`,API 会在任务完成或失败时向 webhook URL 发送 POST 请求。
|
||||
|
||||
**Webhook 请求:**
|
||||
```
|
||||
POST {webhook_url}
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer your-secret-token-here
|
||||
|
||||
{
|
||||
"task_id": "c9h8t1234abcd",
|
||||
"status": "completed",
|
||||
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
|
||||
"created_at": "2024-01-19T04:30:00Z",
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 使用 cURL
|
||||
|
||||
**创建下载任务:**
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/v1/tasks \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-secret-token-here" \
|
||||
-d '{
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}'
|
||||
```
|
||||
|
||||
**获取任务状态:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
**列出所有任务:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/v1/tasks \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
**取消任务:**
|
||||
```bash
|
||||
curl -X DELETE http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
|
||||
-H "Authorization: Bearer your-secret-token-here"
|
||||
```
|
||||
|
||||
### 使用 Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_URL = "http://localhost:8080"
|
||||
TOKEN = "your-secret-token-here"
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {TOKEN}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 创建下载任务
|
||||
response = requests.post(
|
||||
f"{API_URL}/api/v1/tasks",
|
||||
headers=HEADERS,
|
||||
json={
|
||||
"telegram_url": "https://t.me/channel/123",
|
||||
"user_id": 123456789,
|
||||
"storage_name": "local1",
|
||||
"dir_path": "/downloads"
|
||||
}
|
||||
)
|
||||
task_id = response.json()["task_id"]
|
||||
|
||||
# 获取任务状态
|
||||
response = requests.get(
|
||||
f"{API_URL}/api/v1/tasks/{task_id}",
|
||||
headers=HEADERS
|
||||
)
|
||||
status = response.json()
|
||||
print(f"任务状态: {status['status']}")
|
||||
```
|
||||
|
||||
## 安全建议
|
||||
|
||||
1. **生产环境始终使用强令牌**
|
||||
2. **生产环境使用 HTTPS**,通过反向代理(如 Nginx、Caddy)放置 API
|
||||
3. **保护日志安全**,因为它们可能包含敏感信息
|
||||
4. **验证用户权限** - 确保请求中的 `user_id` 对应于配置中的授权用户
|
||||
3
go.mod
3
go.mod
@@ -17,6 +17,7 @@ require (
|
||||
github.com/gotd/td v0.137.0
|
||||
github.com/johannesboyne/gofakes3 v0.0.0-20250916175020-ebf3e50324d3
|
||||
github.com/krau/ffmpeg-go v0.6.0
|
||||
github.com/lrstanley/go-ytdlp v1.2.7
|
||||
github.com/minio/minio-go/v7 v7.0.98
|
||||
github.com/playwright-community/playwright-go v0.5200.1
|
||||
github.com/rs/xid v1.6.0
|
||||
@@ -31,6 +32,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/AnimeKaizoku/cacher v1.0.3 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aws/smithy-go v1.24.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -42,6 +44,7 @@ require (
|
||||
github.com/clipperhouse/displaywidth v0.7.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/coder/websocket v1.8.14 // indirect
|
||||
github.com/deckarep/golang-set/v2 v2.8.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -4,6 +4,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
|
||||
@@ -66,6 +68,8 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
@@ -176,6 +180,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/krau/ffmpeg-go v0.6.0 h1:F4HWvOrKXQsfLsFTOnUfP0HY6WISJqOrsAFGSIzkKto=
|
||||
github.com/krau/ffmpeg-go v0.6.0/go.mod h1:sa7/bWHB6fO9j4lhmxnWQ1U07o+dE1leFjhctotxU7A=
|
||||
github.com/lrstanley/go-ytdlp v1.2.7 h1:YNDvKkd0OCJSZLZePZvJwcirBCfL8Yw3eCwrTCE5w7Q=
|
||||
github.com/lrstanley/go-ytdlp v1.2.7/go.mod h1:38IL64XM6gULrWtKTiR0+TTNCVbxesNSbTyaFG2CGTI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package tasktype
|
||||
|
||||
//go:generate go-enum --values --names --flag --nocase
|
||||
// ENUM(tgfiles,tphpics,parseditem,directlinks)
|
||||
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp)
|
||||
type TaskType string
|
||||
|
||||
@@ -20,6 +20,10 @@ const (
|
||||
TaskTypeParseditem TaskType = "parseditem"
|
||||
// TaskTypeDirectlinks is a TaskType of type directlinks.
|
||||
TaskTypeDirectlinks TaskType = "directlinks"
|
||||
// TaskTypeAria2 is a TaskType of type aria2.
|
||||
TaskTypeAria2 TaskType = "aria2"
|
||||
// TaskTypeYtdlp is a TaskType of type ytdlp.
|
||||
TaskTypeYtdlp TaskType = "ytdlp"
|
||||
)
|
||||
|
||||
var ErrInvalidTaskType = fmt.Errorf("not a valid TaskType, try [%s]", strings.Join(_TaskTypeNames, ", "))
|
||||
@@ -29,6 +33,8 @@ var _TaskTypeNames = []string{
|
||||
string(TaskTypeTphpics),
|
||||
string(TaskTypeParseditem),
|
||||
string(TaskTypeDirectlinks),
|
||||
string(TaskTypeAria2),
|
||||
string(TaskTypeYtdlp),
|
||||
}
|
||||
|
||||
// TaskTypeNames returns a list of possible string values of TaskType.
|
||||
@@ -45,6 +51,8 @@ func TaskTypeValues() []TaskType {
|
||||
TaskTypeTphpics,
|
||||
TaskTypeParseditem,
|
||||
TaskTypeDirectlinks,
|
||||
TaskTypeAria2,
|
||||
TaskTypeYtdlp,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +73,8 @@ var _TaskTypeValue = map[string]TaskType{
|
||||
"tphpics": TaskTypeTphpics,
|
||||
"parseditem": TaskTypeParseditem,
|
||||
"directlinks": TaskTypeDirectlinks,
|
||||
"aria2": TaskTypeAria2,
|
||||
"ytdlp": TaskTypeYtdlp,
|
||||
}
|
||||
|
||||
// ParseTaskType attempts to convert a string to a TaskType.
|
||||
|
||||
@@ -45,6 +45,10 @@ type Add struct {
|
||||
ParsedItem *parser.Item
|
||||
// directlinks
|
||||
DirectLinks []string
|
||||
// aria2
|
||||
Aria2URIs []string
|
||||
// ytdlp
|
||||
YtdlpURLs []string
|
||||
}
|
||||
|
||||
type SetDefaultStorage struct {
|
||||
|
||||
Reference in New Issue
Block a user