Compare commits

..

13 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
ef1ecf960f Fix album (NEW-FOR-ALBUM) rule handling and proper filename logic
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 07:32:52 +00:00
copilot-swe-agent[bot]
e1b4087801 Add storage rule support to API task creation
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 07:17:35 +00:00
copilot-swe-agent[bot]
6896bdc852 Fix performance issues, add media group support, and improve filename handling
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 06:14:48 +00:00
copilot-swe-agent[bot]
3a6402a71b Add validation to require API token when API is enabled
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 05:49:12 +00:00
copilot-swe-agent[bot]
173a5e3733 Remove IP whitelist mechanism, keep only token authentication
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 05:42:14 +00:00
copilot-swe-agent[bot]
20a5e317ae Address PR feedback: remove redundant files, format code, add progress tracking
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 05:27:12 +00:00
copilot-swe-agent[bot]
127901fd24 Final improvements: better user ID validation, safer IP handling, context-aware logging
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:52:43 +00:00
copilot-swe-agent[bot]
30c165033e Address all security review comments: sanitize remaining error messages and handle errors properly
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:50:48 +00:00
copilot-swe-agent[bot]
9dcb5201e1 Fix security issues: sanitize error messages and fix test port
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:49:03 +00:00
copilot-swe-agent[bot]
7b0142ef82 Add API tests and test script
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:46:08 +00:00
copilot-swe-agent[bot]
2f6b2470a4 Add API documentation in English and Chinese
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:44:20 +00:00
copilot-swe-agent[bot]
ac10c32215 Add HTTP API server for file downloads from Telegram links
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
2026-01-19 04:42:55 +00:00
copilot-swe-agent[bot]
7def7f5b28 Initial plan 2026-01-19 04:32:43 +00:00
33 changed files with 1422 additions and 351 deletions

629
api/handlers.go Normal file
View File

@@ -0,0 +1,629 @@
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/client/bot"
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/ruleutil"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
tftask "github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/queue"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
// Request/Response types
type CreateTaskRequest struct {
TelegramURL string `json:"telegram_url"`
StorageName string `json:"storage_name,omitempty"`
DirPath string `json:"dir_path,omitempty"`
UserID int64 `json:"user_id"`
}
type CreateTaskResponse struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
}
type TaskStatusResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"` // queued, running, completed, failed, canceled
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
Error string `json:"error,omitempty"`
Downloaded int64 `json:"downloaded,omitempty"` // Bytes downloaded
Total int64 `json:"total,omitempty"` // Total bytes
ProgressPct float64 `json:"progress_pct,omitempty"` // Progress percentage (0-100)
}
type ListTasksResponse struct {
Queued []TaskInfo `json:"queued"`
Running []TaskInfo `json:"running"`
}
type TaskInfo struct {
ID string `json:"id"`
Title string `json:"title"`
}
type ErrorResponse struct {
Error string `json:"error"`
}
// Task tracking
var (
taskStatuses = make(map[string]*taskStatus)
taskStatusesMu sync.RWMutex
)
type taskStatus struct {
ID string
Status string
Title string
CreatedAt time.Time
Error string
Downloaded atomic.Int64 // Use atomic for lock-free updates
Total atomic.Int64 // Use atomic for lock-free updates
ProgressPct uint64 // Store as uint64 bits of float64 for atomic access
}
func handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func handleCreateTask(w http.ResponseWriter, r *http.Request) {
var req CreateTaskRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondError(w, "invalid request body", http.StatusBadRequest)
return
}
// Validate request
if req.TelegramURL == "" {
respondError(w, "telegram_url is required", http.StatusBadRequest)
return
}
if req.UserID <= 0 {
respondError(w, "user_id is required and must be positive", http.StatusBadRequest)
return
}
logger := log.FromContext(r.Context()).WithPrefix("api")
// Get user from database
userDB, err := database.GetUserByChatID(r.Context(), req.UserID)
if err != nil {
logger.Errorf("Failed to get user: %v", err)
respondError(w, "user not found", http.StatusBadRequest)
return
}
// Get storage
var stor storage.Storage
if req.StorageName != "" {
stor, err = storage.GetStorageByUserIDAndName(r.Context(), req.UserID, req.StorageName)
if err != nil {
logger.Errorf("Failed to get storage: %v", err)
respondError(w, "storage not found", http.StatusBadRequest)
return
}
} else {
// Use first available storage for the user
storages := storage.GetUserStorages(r.Context(), req.UserID)
if len(storages) == 0 {
respondError(w, "no storage available for user", http.StatusBadRequest)
return
}
stor = storages[0]
}
// Parse Telegram URL
botCtx := bot.ExtContext()
if botCtx == nil {
respondError(w, "bot not initialized", http.StatusInternalServerError)
return
}
linkUrl, err := url.Parse(req.TelegramURL)
if err != nil {
logger.Errorf("Failed to parse URL: %v", err)
respondError(w, "invalid telegram URL format", http.StatusBadRequest)
return
}
chatID, msgID, err := tgutil.ParseMessageLink(botCtx, req.TelegramURL)
if err != nil {
logger.Errorf("Failed to parse Telegram URL: %v", err)
respondError(w, "invalid telegram URL format", http.StatusBadRequest)
return
}
// Get message from Telegram
msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID)
if err != nil {
logger.Errorf("Failed to get message: %v", err)
respondError(w, "failed to retrieve message", http.StatusBadRequest)
return
}
// Check if message has media
media, ok := msg.GetMedia()
if !ok {
respondError(w, "message has no media", http.StatusBadRequest)
return
}
// Collect files - handle both single and grouped messages
files := make([]tfile.TGFileMessage, 0)
// Check for grouped messages (media group)
groupID, isGroup := msg.GetGroupedID()
if isGroup && groupID != 0 && !linkUrl.Query().Has("single") {
// Handle media group
gmsgs, err := tgutil.GetGroupedMessages(botCtx, chatID, msg)
if err != nil {
logger.Errorf("Failed to get grouped messages: %v", err)
// Fall back to single message
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, "invalid message format", http.StatusBadRequest)
return
}
files = append(files, file)
} else {
// Process all messages in the group
for _, gmsg := range gmsgs {
if gmsg.Media == nil {
continue
}
gMedia, ok := gmsg.GetMedia()
if !ok {
continue
}
file, err := createTGFileWithMedia(botCtx, gmsg, gMedia, userDB)
if err != nil {
logger.Warnf("Failed to create TGFile for grouped message: %v", err)
continue
}
files = append(files, file)
}
}
} else {
// Single message
file, err := createTGFileWithMedia(botCtx, msg, media, userDB)
if err != nil {
logger.Errorf("Failed to create TGFile: %v", err)
respondError(w, "invalid message format", http.StatusBadRequest)
return
}
files = append(files, file)
}
if len(files) == 0 {
respondError(w, "no savable files found", http.StatusBadRequest)
return
}
// Create tasks for all files with proper album handling
taskIDs := make([]string, 0, len(files))
baseDirPath := req.DirPath
if baseDirPath == "" {
baseDirPath = "/"
}
// Create context with bot extension
injectCtx := tgutil.ExtWithContext(r.Context(), botCtx)
// Apply storage rules if enabled for the user
useRule := userDB.ApplyRule && userDB.Rules != nil
// Helper to apply rules to a file
applyRule := func(file tfile.TGFileMessage) (storage.Storage, ruleutil.MatchedDirPath) {
fileStor := stor
dirPath := ruleutil.MatchedDirPath(baseDirPath)
if useRule {
matched, matchedStorName, matchedDirPath := ruleutil.ApplyRule(injectCtx, userDB.Rules, ruleutil.NewInput(file))
if matched {
// Rule matched, apply overrides
if matchedDirPath != "" {
dirPath = matchedDirPath
}
if matchedStorName.Usable() {
var err error
fileStor, err = storage.GetStorageByUserIDAndName(injectCtx, userDB.ChatID, matchedStorName.String())
if err != nil {
logger.Errorf("Failed to get storage from rule: %v", err)
// Fall back to original storage
fileStor = stor
}
}
}
}
return fileStor, dirPath
}
// Separate files into regular and album files
type albumFile struct {
file tfile.TGFileMessage
storage storage.Storage
dirPath ruleutil.MatchedDirPath
}
albumFiles := make(map[int64][]albumFile)
for _, tgFile := range files {
fileStor, dirPath := applyRule(tgFile)
// Check if this needs album handling (NEW-FOR-ALBUM rule)
if dirPath.NeedNewForAlbum() {
groupId, isGroup := tgFile.Message().GetGroupedID()
if !isGroup || groupId == 0 {
logger.Warnf("File %s has NEW-FOR-ALBUM rule but is not in a group, treating as regular file", tgFile.Name())
// Treat as regular file with base dir path
storagePath := fileStor.JoinStoragePath(path.Join(baseDirPath, tgFile.Name()))
taskID := xid.New().String()
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
trackTask(taskID, task.Title(), "queued")
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
}
taskIDs = append(taskIDs, taskID)
continue
}
// Group by album ID
if _, ok := albumFiles[groupId]; !ok {
albumFiles[groupId] = make([]albumFile, 0)
}
albumFiles[groupId] = append(albumFiles[groupId], albumFile{
file: tgFile,
storage: fileStor,
dirPath: dirPath,
})
} else {
// Regular file - create task immediately
storagePath := fileStor.JoinStoragePath(path.Join(dirPath.String(), tgFile.Name()))
taskID := xid.New().String()
task, err := tftask.NewTGFileTask(taskID, injectCtx, tgFile, fileStor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
trackTask(taskID, task.Title(), "queued")
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
}
taskIDs = append(taskIDs, taskID)
}
}
// Handle album files - group them into subdirectories
for _, afiles := range albumFiles {
if len(afiles) <= 1 {
// Single file in album, treat as regular
for _, af := range afiles {
storagePath := af.storage.JoinStoragePath(path.Join(baseDirPath, af.file.Name()))
taskID := xid.New().String()
task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, af.storage, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
trackTask(taskID, task.Title(), "queued")
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
}
taskIDs = append(taskIDs, taskID)
}
continue
}
// Multiple files in album - create subdirectory named after first file
// Remove extension from first file's name to use as directory name
albumDir := strings.TrimSuffix(path.Base(afiles[0].file.Name()), path.Ext(afiles[0].file.Name()))
albumStor := afiles[0].storage
for _, af := range afiles {
// All files go into the album subdirectory
storagePath := albumStor.JoinStoragePath(path.Join(baseDirPath, albumDir, af.file.Name()))
taskID := xid.New().String()
task, err := tftask.NewTGFileTask(taskID, injectCtx, af.file, albumStor, storagePath, &apiProgressTracker{
taskID: taskID,
})
if err != nil {
logger.Errorf("Failed to create task for album file: %v", err)
respondError(w, "failed to create task", http.StatusInternalServerError)
return
}
trackTask(taskID, task.Title(), "queued")
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add task: %v", err)
updateTaskStatus(taskID, "failed", err.Error())
respondError(w, "failed to add task to queue", http.StatusInternalServerError)
return
}
taskIDs = append(taskIDs, taskID)
}
}
// Send success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
// Return first task ID for single file, or all task IDs for media group
if len(taskIDs) == 1 {
json.NewEncoder(w).Encode(CreateTaskResponse{
TaskID: taskIDs[0],
Message: "task created successfully",
})
} else {
json.NewEncoder(w).Encode(map[string]interface{}{
"task_ids": taskIDs,
"message": fmt.Sprintf("%d tasks created successfully", len(taskIDs)),
})
}
}
// createTGFileWithMedia creates a TGFile with proper filename handling using user's strategy
func createTGFileWithMedia(botCtx *ext.Context, msg *tg.Message, media tg.MessageMediaClass, userDB *database.User) (tfile.TGFileMessage, error) {
// Use the same filename generation logic as bot handlers
opts := []tfile.TGFileOption{tfile.WithNameIfEmpty(tgutil.GenFileNameFromMessage(*msg))}
return tfile.FromMediaMessage(media, botCtx.Raw, msg, opts...)
}
func handleGetTask(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
if taskID == "" {
respondError(w, "task_id is required", http.StatusBadRequest)
return
}
taskStatusesMu.RLock()
status, exists := taskStatuses[taskID]
taskStatusesMu.RUnlock()
if !exists {
respondError(w, "task not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TaskStatusResponse{
TaskID: status.ID,
Status: status.Status,
Title: status.Title,
CreatedAt: status.CreatedAt,
Error: status.Error,
Downloaded: status.Downloaded.Load(),
Total: status.Total.Load(),
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&status.ProgressPct))),
})
}
func handleListTasks(w http.ResponseWriter, r *http.Request) {
queued := core.GetQueuedTasks(r.Context())
running := core.GetRunningTasks(r.Context())
response := ListTasksResponse{
Queued: convertTaskInfos(queued),
Running: convertTaskInfos(running),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
func handleCancelTask(w http.ResponseWriter, r *http.Request) {
taskID := r.PathValue("id")
if taskID == "" {
respondError(w, "task_id is required", http.StatusBadRequest)
return
}
if err := core.CancelTask(r.Context(), taskID); err != nil {
log.FromContext(r.Context()).Errorf("Failed to cancel task %s: %v", taskID, err)
respondError(w, "failed to cancel task", http.StatusInternalServerError)
return
}
updateTaskStatus(taskID, "canceled", "")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"message": "task canceled"})
}
// Helper functions
func respondError(w http.ResponseWriter, message string, statusCode int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(ErrorResponse{Error: message})
}
func trackTask(taskID, title, status string) {
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
taskStatuses[taskID] = &taskStatus{
ID: taskID,
Status: status,
Title: title,
CreatedAt: time.Now(),
}
}
func updateTaskStatus(taskID, status, errorMsg string) {
taskStatusesMu.Lock()
defer taskStatusesMu.Unlock()
if ts, exists := taskStatuses[taskID]; exists {
ts.Status = status
ts.Error = errorMsg
}
}
func convertTaskInfos(tasks []queue.TaskInfo) []TaskInfo {
result := make([]TaskInfo, len(tasks))
for i, t := range tasks {
result[i] = TaskInfo{
ID: t.ID,
Title: t.Title,
}
}
return result
}
// apiProgressTracker implements tftask.ProgressTracker for API tasks
type apiProgressTracker struct {
taskID string
}
func (a *apiProgressTracker) OnStart(ctx context.Context, info tftask.TaskInfo) {
updateTaskStatus(a.taskID, "running", "")
}
func (a *apiProgressTracker) OnProgress(ctx context.Context, info tftask.TaskInfo, downloaded int64, total int64) {
// Use atomic operations to avoid mutex locks for better performance
// OnProgress is called very frequently during downloads
taskStatusesMu.RLock()
ts, exists := taskStatuses[a.taskID]
taskStatusesMu.RUnlock()
if exists {
ts.Downloaded.Store(downloaded)
ts.Total.Store(total)
if total > 0 {
progressPct := float64(downloaded) / float64(total) * 100.0
atomic.StoreUint64((*uint64)(&ts.ProgressPct), math.Float64bits(progressPct))
}
}
}
func (a *apiProgressTracker) OnDone(ctx context.Context, info tftask.TaskInfo, err error) {
if err != nil {
updateTaskStatus(a.taskID, "failed", err.Error())
sendWebhook(a.taskID, "failed", err.Error())
} else {
updateTaskStatus(a.taskID, "completed", "")
sendWebhook(a.taskID, "completed", "")
}
}
// sendWebhook sends a callback to the configured webhook URL
func sendWebhook(taskID, status, errorMsg string) {
cfg := config.C()
if cfg.API.WebhookURL == "" {
return
}
taskStatusesMu.RLock()
ts, exists := taskStatuses[taskID]
taskStatusesMu.RUnlock()
if !exists {
return
}
logger := log.WithPrefix("webhook")
payload := TaskStatusResponse{
TaskID: ts.ID,
Status: status,
Title: ts.Title,
CreatedAt: ts.CreatedAt,
Error: errorMsg,
Downloaded: ts.Downloaded.Load(),
Total: ts.Total.Load(),
ProgressPct: math.Float64frombits(atomic.LoadUint64((*uint64)(&ts.ProgressPct))),
}
body, err := json.Marshal(payload)
if err != nil {
logger.Errorf("Failed to marshal webhook payload: %v", err)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", cfg.API.WebhookURL, bytes.NewReader(body))
if err != nil {
logger.Errorf("Failed to create webhook request: %v", err)
return
}
req.Header.Set("Content-Type", "application/json")
if cfg.API.Token != "" {
req.Header.Set("Authorization", "Bearer "+cfg.API.Token)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Errorf("Failed to send webhook: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("Webhook returned error status %d, failed to read response body: %v", resp.StatusCode, err)
} else {
logger.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body))
}
}
}

66
api/middleware.go Normal file
View File

@@ -0,0 +1,66 @@
package api
import (
"net/http"
"strings"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
// authMiddleware validates API token
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for health check
if r.URL.Path == "/health" {
next.ServeHTTP(w, r)
return
}
cfg := config.C()
// Check token if configured
if cfg.API.Token != "" {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error":"unauthorized: missing token"}`, http.StatusUnauthorized)
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if token != cfg.API.Token {
http.Error(w, `{"error":"unauthorized: invalid token"}`, http.StatusUnauthorized)
return
}
}
next.ServeHTTP(w, r)
})
}
// loggingMiddleware logs HTTP requests
func loggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Wrap response writer to capture status code
wrapper := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapper, r)
logger.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapper.statusCode, time.Since(start))
})
}
}
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}

47
api/middleware_test.go Normal file
View File

@@ -0,0 +1,47 @@
package api
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestAuthMiddleware_NoAuth(t *testing.T) {
// Create a test handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// Create request
req := httptest.NewRequest("GET", "/api/v1/tasks", nil)
rec := httptest.NewRecorder()
// Apply middleware
authMiddleware(handler).ServeHTTP(rec, req)
// When no token is configured, request should succeed or be unauthorized
if rec.Code != http.StatusOK && rec.Code != http.StatusUnauthorized {
t.Errorf("Expected status 200 or 401, got %d", rec.Code)
}
}
func TestAuthMiddleware_HealthCheck(t *testing.T) {
// Create a test handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// Create request to health endpoint
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
// Apply middleware
authMiddleware(handler).ServeHTTP(rec, req)
// Health check should always work without auth
if rec.Code != http.StatusOK {
t.Errorf("Expected status 200 for health check, got %d", rec.Code)
}
}

77
api/server.go Normal file
View File

@@ -0,0 +1,77 @@
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
var server *http.Server
// Init initializes and starts the HTTP API server
func Init(ctx context.Context) error {
cfg := config.C()
if !cfg.API.Enable {
return nil
}
// Validate that token is configured when API is enabled
if cfg.API.Token == "" {
return fmt.Errorf("API is enabled but token is not configured. Please set 'api.token' in your configuration file for security")
}
logger := log.FromContext(ctx).WithPrefix("api")
mux := http.NewServeMux()
// Register API routes
registerRoutes(mux)
// Wrap with middleware
handler := loggingMiddleware(logger)(authMiddleware(mux))
server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.API.Port),
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
logger.Infof("Starting API server on port %d", cfg.API.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Errorf("API server error: %v", err)
}
}()
// Graceful shutdown on context cancellation
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
logger.Errorf("Failed to shutdown API server: %v", err)
} else {
logger.Info("API server stopped")
}
}()
return nil
}
func registerRoutes(mux *http.ServeMux) {
// Health check endpoint (no auth required)
mux.HandleFunc("/health", handleHealth)
// API v1 endpoints
mux.HandleFunc("POST /api/v1/tasks", handleCreateTask)
mux.HandleFunc("GET /api/v1/tasks/{id}", handleGetTask)
mux.HandleFunc("GET /api/v1/tasks", handleListTasks)
mux.HandleFunc("DELETE /api/v1/tasks/{id}", handleCancelTask)
}

View File

@@ -100,7 +100,7 @@ func handleAddCallback(ctx *ext.Context, update *ext.Update) error {
}
shortcut.CreateAndAddAria2TaskWithEdit(ctx, selectedStorage, dirPath, data.Aria2URIs, client, msgID, userID)
case tasktype.TaskTypeYtdlp:
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, data.YtdlpFlags, msgID, userID)
shortcut.CreateAndAddYtdlpTaskWithEdit(ctx, selectedStorage, dirPath, data.YtdlpURLs, msgID, userID)
default:
return fmt.Errorf("unexcept task type: %s", data.TaskType)
}

View File

@@ -84,7 +84,7 @@ func handleAria2DlCmd(ctx *ext.Context, update *ext.Update) error {
return nil
}
logger.Debug("Preparing aria2 download", "links", links)
// Initialize aria2 client to check connection
aria2ClientInitOnce.Do(func() {
aria2Client, aria2ClientInitErr = aria2.NewClient(config.C().Aria2.Url, config.C().Aria2.Secret)

View File

@@ -114,7 +114,7 @@ func processMediaGroup(ctx *ext.Context, update *ext.Update, groupID int64) {
if err != nil {
logger.Errorf("Failed to build storage selection keyboard: %s", err)
ctx.EditMessage(userId, &tg.MessagesEditMessageRequest{
ID: msg.ID,
ID: msg.ID,
Message: i18n.T(i18nk.BotMsgMediaGroupErrorBuildStorageSelectKeyboardFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -38,7 +38,7 @@ func handleTaskCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextStyledTextArray([]styling.StyledTextOption{
styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)),
styling.Plain(i18n.T(i18nk.BotMsgTasksCancelRequestedPrefix)),
styling.Code(taskID),
}), nil)
default:

View File

@@ -103,7 +103,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
return err
}
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradingWithVersion, map[string]any{
"Current": config.Version,
}),
@@ -111,7 +111,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
latest, err := ghselfupdate.UpdateSelf(currentV, config.GitRepo)
if err != nil {
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateErrorUpgradeFailed, map[string]any{
"Error": err.Error(),
}),
@@ -119,7 +119,7 @@ func handleUpdateCallback(ctx *ext.Context, u *ext.Update) error {
return dispatcher.EndGroups
}
ctx.EditMessage(u.GetUserChat().GetID(), &tg.MessagesEditMessageRequest{
ID: u.CallbackQuery.GetMsgID(),
ID: u.CallbackQuery.GetMsgID(),
Message: i18n.T(i18nk.BotMsgUpdateInfoUpgradeSuccess, map[string]any{
"Version": latest.Version.String(),
}),

View File

@@ -112,7 +112,7 @@ func BuildFilenameTemplateData(message *tg.Message) map[string]string {
}(),
MsgRaw: message.GetMessage(),
ChatID: func() string {
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id,
// 如果消息是频道的(从消息链接中fetch的) 直接使用其chat id,
// 无论它是否是从其他来源转发的
if message.GetPost() {
peer := message.GetPeerID()

View File

@@ -50,9 +50,8 @@ func BuildAddSelectStorageKeyboard(stors []storage.Storage, adddata tcbdata.Add)
DirectLinks: adddata.DirectLinks,
Aria2URIs: adddata.Aria2URIs,
YtdlpURLs: adddata.YtdlpURLs,
YtdlpFlags: adddata.YtdlpFlags,
Aria2URIs: adddata.Aria2URIs,
YtdlpURLs: adddata.YtdlpURLs,
}
dataid := xid.New().String()
err := cache.Set(dataid, data)

View File

@@ -22,7 +22,7 @@ func CreateAndAddParsedTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirP
if err := core.AddTask(injectCtx, task); err != nil {
log.FromContext(ctx).Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: msgID,
ID: msgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -29,7 +29,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage
if err != nil {
logger.Errorf("Failed to get user by chat ID: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{
"Error": err.Error(),
}),
@@ -49,7 +49,7 @@ func CreateAndAddTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor storage
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{
"Error": err.Error(),
}),
@@ -69,7 +69,7 @@ startCreateTask:
if err != nil {
logger.Errorf("create task failed: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
@@ -79,7 +79,7 @@ startCreateTask:
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("add task failed: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
@@ -103,7 +103,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to get user by chat ID: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetUserWithErrFailed, map[string]any{
"Error": err.Error(),
}),
@@ -142,7 +142,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to get storage by user ID and name: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorGetStorageFailed, map[string]any{
"Error": err.Error(),
}),
@@ -156,10 +156,10 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to create task element: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
})
return dispatcher.EndGroups
}
@@ -193,7 +193,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err != nil {
logger.Errorf("Failed to create task element for album file: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskCreateFailed, map[string]any{
"Error": err.Error(),
}),
@@ -210,7 +210,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
if err := core.AddTask(injectCtx, task); err != nil {
logger.Errorf("Failed to add batch task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),
@@ -218,8 +218,8 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
return dispatcher.EndGroups
}
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonInfoBatchTasksAdded, map[string]any{
"Count": len(files),
}),
ReplyMarkup: nil,

View File

@@ -25,7 +25,7 @@ func CreateAndAddtelegraphWithEdit(
pics []string,
stor storage.Storage,
trackMsgID int) error {
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
task := tphtask.NewTask(xid.New().String(),
injectCtx,
@@ -39,7 +39,7 @@ func CreateAndAddtelegraphWithEdit(
if err := core.AddTask(injectCtx, task); err != nil {
log.FromContext(ctx).Errorf("Failed to add task: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
ID: trackMsgID,
ID: trackMsgID,
Message: i18n.T(i18nk.BotMsgCommonErrorTaskAddFailed, map[string]any{
"Error": err.Error(),
}),

View File

@@ -15,7 +15,7 @@ import (
"github.com/krau/SaveAny-Bot/storage"
)
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, flags []string, msgID int, userID int64) error {
func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPath string, urls []string, msgID int, userID int64) error {
logger := log.FromContext(ctx)
injectCtx := tgutil.ExtWithContext(ctx.Context, ctx)
@@ -29,14 +29,13 @@ func CreateAndAddYtdlpTaskWithEdit(ctx *ext.Context, stor storage.Storage, dirPa
return dispatcher.EndGroups
}
logger.Infof("Creating yt-dlp task for %d URL(s) with %d flag(s)", len(urls), len(flags))
logger.Infof("Creating yt-dlp task for %d URL(s)", len(urls))
// Create yt-dlp task
task := ytdlp.NewTask(
xid.New().String(),
injectCtx,
urls,
flags,
stor,
stor.JoinStoragePath(dirPath),
ytdlp.NewProgress(msgID, userID),

View File

@@ -7,6 +7,7 @@ import (
"github.com/celestix/gotgproto/dispatcher"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/duke-git/lancet/v2/slice"
"github.com/krau/SaveAny-Bot/client/bot/handlers/utils/msgelem"
"github.com/krau/SaveAny-Bot/common/i18n"
@@ -24,59 +25,29 @@ func handleYtdlpCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups
}
// Separate URLs and flags from arguments
var urls []string
var flags []string
for i := 1; i < len(args); i++ {
arg := strings.TrimSpace(args[i])
if arg == "" {
continue
}
// Check if it's a flag (starts with - or --)
if strings.HasPrefix(arg, "-") {
flags = append(flags, arg)
// Check if the next argument might be a value for this flag
// Don't consume it if it starts with - or looks like a URL with scheme
if i+1 < len(args) {
nextArg := strings.TrimSpace(args[i+1])
if nextArg != "" && !strings.HasPrefix(nextArg, "-") {
// Check if it's clearly a URL (has ://)
// This handles common video URLs (http://, https://)
// For other yt-dlp inputs, users should ensure proper formatting
if strings.Contains(nextArg, "://") {
// It's a URL, don't consume it as a flag value
continue
}
// Otherwise, treat it as a flag value
flags = append(flags, nextArg)
i++ // Skip the next argument as it's been consumed
}
}
} else {
// Try to parse as URL
u, err := url.Parse(arg)
if err != nil || u.Scheme == "" || u.Host == "" {
logger.Warnf("Invalid URL: %s", arg)
continue
}
urls = append(urls, arg)
urls := args[1:]
// Validate and clean URLs
for i, link := range urls {
urls[i] = strings.TrimSpace(link)
u, err := url.Parse(link)
if err != nil || u.Scheme == "" || u.Host == "" {
logger.Warnf("Invalid URL: %s", link)
urls[i] = ""
}
}
urls = slice.Compact(urls)
if len(urls) == 0 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgYtdlpErrorNoValidUrls)), nil)
return dispatcher.EndGroups
}
logger.Debugf("Preparing yt-dlp download for %d URL(s) with %d flag(s)", len(urls), len(flags))
logger.Debugf("Preparing yt-dlp download for %d URL(s)", len(urls))
// Build storage selection keyboard
markup, err := msgelem.BuildAddSelectStorageKeyboard(storage.GetUserStorages(ctx, update.GetUserChat().GetID()), tcbdata.Add{
TaskType: tasktype.TaskTypeYtdlp,
YtdlpURLs: urls,
YtdlpFlags: flags,
TaskType: tasktype.TaskTypeYtdlp,
YtdlpURLs: urls,
})
if err != nil {
return err

View File

@@ -1,129 +0,0 @@
package handlers
import (
"net/url"
"strings"
"testing"
)
// TestYtdlpArgumentParsing tests the URL and flag separation logic
func TestYtdlpArgumentParsing(t *testing.T) {
tests := []struct {
name string
input string
expectedURLs []string
expectedFlags []string
}{
{
name: "Single URL without flags",
input: "/ytdlp https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{},
},
{
name: "Multiple URLs without flags",
input: "/ytdlp https://example.com/v1 https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{},
},
{
name: "URL with format flag",
input: "/ytdlp --format best https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--format", "best"},
},
{
name: "URL with extract-audio flag",
input: "/ytdlp --extract-audio --audio-format mp3 https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--extract-audio", "--audio-format", "mp3"},
},
{
name: "Multiple URLs with flags",
input: "/ytdlp --format best https://example.com/v1 https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{"--format", "best"},
},
{
name: "Flags mixed with URLs",
input: "/ytdlp https://example.com/v1 --format best https://example.com/v2",
expectedURLs: []string{"https://example.com/v1", "https://example.com/v2"},
expectedFlags: []string{"--format", "best"},
},
{
name: "Short flag",
input: "/ytdlp -f best https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"-f", "best"},
},
{
name: "Boolean flag",
input: "/ytdlp --extract-audio https://example.com/video",
expectedURLs: []string{"https://example.com/video"},
expectedFlags: []string{"--extract-audio"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := strings.Split(tt.input, " ")
// Simulate the parsing logic from handleYtdlpCmd
var urls []string
var flags []string
for i := 1; i < len(args); i++ {
arg := strings.TrimSpace(args[i])
if arg == "" {
continue
}
// Check if it's a flag (starts with - or --)
if strings.HasPrefix(arg, "-") {
flags = append(flags, arg)
// Check if the next argument might be a value for this flag
if i+1 < len(args) {
nextArg := strings.TrimSpace(args[i+1])
if nextArg != "" && !strings.HasPrefix(nextArg, "-") {
// Check if it's clearly a URL (has ://)
if strings.Contains(nextArg, "://") {
// It's a URL, don't consume it as a flag value
continue
}
// Otherwise, treat it as a flag value
flags = append(flags, nextArg)
i++ // Skip the next argument as it's been consumed
}
}
} else {
// Try to parse as URL
u, err := url.Parse(arg)
if err != nil || u.Scheme == "" || u.Host == "" {
continue
}
urls = append(urls, arg)
}
}
// Verify URLs
if len(urls) != len(tt.expectedURLs) {
t.Errorf("Expected %d URLs, got %d", len(tt.expectedURLs), len(urls))
}
for i, expectedURL := range tt.expectedURLs {
if i >= len(urls) || urls[i] != expectedURL {
t.Errorf("Expected URL[%d] to be '%s', got '%s'", i, expectedURL, urls[i])
}
}
// Verify flags
if len(flags) != len(tt.expectedFlags) {
t.Errorf("Expected %d flags, got %d", len(tt.expectedFlags), len(flags))
}
for i, expectedFlag := range tt.expectedFlags {
if i >= len(flags) || flags[i] != expectedFlag {
t.Errorf("Expected flag[%d] to be '%s', got '%s'", i, expectedFlag, flags[i])
}
}
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"slices"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/api"
"github.com/krau/SaveAny-Bot/client/bot"
userclient "github.com/krau/SaveAny-Bot/client/user"
"github.com/krau/SaveAny-Bot/common/cache"
@@ -76,7 +77,11 @@ func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) {
logger.Fatal("User login failed", "error", err)
}
}
return bot.Init(ctx), nil
exitChan := bot.Init(ctx)
if err := api.Init(ctx); err != nil {
return nil, fmt.Errorf("failed to init API server: %w", err)
}
return exitChan, nil
}
func cleanCache() {

View File

@@ -289,7 +289,7 @@ bot:
error_no_valid_links: "No valid links to download"
info_files_select_storage: "Total {{.Count}} files, please select storage"
ytdlp:
usage: "Usage: /ytdlp [OPTIONS] <URL1> [URL2] ...\nExamples:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video"
usage: "Usage: /ytdlp <URL1> <URL2> ..."
error_no_valid_urls: "No valid URLs"
info_urls_select_storage: "Found {{.Count}} links, please select storage"
info_downloading: "Downloading via yt-dlp..."

View File

@@ -290,7 +290,7 @@ bot:
error_no_valid_links: "没有有效的链接可供下载"
info_files_select_storage: "共 {{.Count}} 个文件, 请选择存储位置"
ytdlp:
usage: "用法: /ytdlp [选项] <URL1> [URL2] ...\n示例:\n /ytdlp https://example.com/video\n /ytdlp --format best https://example.com/video\n /ytdlp --extract-audio --audio-format mp3 https://example.com/video"
usage: "用法: /ytdlp <URL1> <URL2> ..."
error_no_valid_urls: "没有有效的 URL"
info_urls_select_storage: "共 {{.Count}} 个链接, 请选择存储位置"
info_downloading: "正在通过 yt-dlp 下载..."

View File

@@ -48,4 +48,4 @@ func NewProgressWriter(
wr: wr,
onWrite: onWrite,
}
}
}

View File

@@ -29,6 +29,17 @@ secret = ""
# 转存完成后删除 Aria2 下载的本地文件
remove_after_transfer = true
# HTTP API 配置
[api]
# 启用 HTTP API 服务
enable = false
# API 服务监听端口
port = 8080
# API 访问令牌 (留空则不验证)
token = ""
# 任务完成回调 Webhook URL (留空则不回调)
webhook_url = ""
# 存储列表
[[storages]]
# 标识名, 需要唯一

View File

@@ -33,6 +33,7 @@ type Config struct {
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
Parser parserConfig `toml:"parser" mapstructure:"parser" json:"parser"`
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
API apiConfig `toml:"api" mapstructure:"api" json:"api"`
}
type aria2Config struct {
@@ -42,6 +43,13 @@ type aria2Config struct {
KeepFile bool `toml:"keep_file" mapstructure:"keep_file" json:"keep_file"`
}
type apiConfig struct {
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
Port int `toml:"port" mapstructure:"port" json:"port"`
Token string `toml:"token" mapstructure:"token" json:"token"`
WebhookURL string `toml:"webhook_url" mapstructure:"webhook_url" json:"webhook_url"`
}
var cfg = &Config{}
func C() Config {
@@ -115,6 +123,12 @@ func Init(ctx context.Context, configFile ...string) error {
// 数据库
"db.path": "data/saveany.db",
"db.session": "data/session.db",
// API
"api.enable": false,
"api.port": 8080,
"api.token": "",
"api.webhook_url": "",
}
for key, value := range defaultConfigs {

View File

@@ -80,34 +80,22 @@ func (t *Task) Execute(ctx context.Context) error {
func (t *Task) downloadFiles(ctx context.Context, tempDir string) ([]string, error) {
logger := log.FromContext(ctx)
// Configure yt-dlp command with essential settings
// Always set output path to ensure files go to temp directory
// Configure yt-dlp command
cmd := ytdlp.New().
Output(filepath.Join(tempDir, "%(title)s.%(ext)s"))
// If no custom flags are provided, use default behavior
if len(t.Flags) == 0 {
cmd = cmd.
FormatSort("res,ext:mp4:m4a").
RecodeVideo("mp4").
RestrictFilenames()
}
// Note: If custom flags are provided, users have full control over format/quality
// The output path is always set above to ensure downloads go to the correct directory
FormatSort("res,ext:mp4:m4a").
RecodeVideo("mp4").
Output(filepath.Join(tempDir, "%(title)s.%(ext)s")).
RestrictFilenames()
if t.Progress != nil {
t.Progress.OnProgress(ctx, t, "Downloading...")
}
// Execute download with URLs and custom flags
logger.Infof("Executing yt-dlp for %d URL(s) with %d custom flag(s)", len(t.URLs), len(t.Flags))
// Combine flags and URLs as arguments (flags first, then URLs)
// yt-dlp accepts: yt-dlp [OPTIONS] URL [URL...]
args := append(t.Flags, t.URLs...)
// Execute download with URLs as arguments
logger.Infof("Executing yt-dlp for %d URL(s)", len(t.URLs))
// Run with context for cancellation support
result, err := cmd.Run(ctx, args...)
result, err := cmd.Run(ctx, t.URLs...)
if err != nil {
// Check if context was canceled
if errors.Is(err, context.Canceled) {

View File

@@ -15,7 +15,6 @@ type Task struct {
ID string
ctx context.Context
URLs []string
Flags []string
Storage storage.Storage
StorPath string
Progress ProgressTracker
@@ -44,7 +43,6 @@ func NewTask(
id string,
ctx context.Context,
urls []string,
flags []string,
stor storage.Storage,
storPath string,
progressTracker ProgressTracker,
@@ -53,7 +51,6 @@ func NewTask(
ID: id,
ctx: ctx,
URLs: urls,
Flags: flags,
Storage: stor,
StorPath: storPath,
Progress: progressTracker,

View File

@@ -1,114 +0,0 @@
package ytdlp
import (
"context"
"io"
"testing"
storcfg "github.com/krau/SaveAny-Bot/config/storage"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
)
// MockStorage is a simple mock for testing
type MockStorage struct{}
func (m *MockStorage) Init(ctx context.Context, cfg storcfg.StorageConfig) error { return nil }
func (m *MockStorage) Type() storenum.StorageType { return "mock" }
func (m *MockStorage) Name() string { return "test-storage" }
func (m *MockStorage) JoinStoragePath(p string) string { return "test-path" }
func (m *MockStorage) Save(ctx context.Context, reader io.Reader, path string) error { return nil }
func (m *MockStorage) Exists(ctx context.Context, path string) bool { return false }
func TestNewTask(t *testing.T) {
ctx := context.Background()
urls := []string{"https://example.com/video"}
flags := []string{"--format", "best"}
stor := &MockStorage{}
storPath := "test-path"
task := NewTask("test-id", ctx, urls, flags, stor, storPath, nil)
if task == nil {
t.Fatal("NewTask returned nil")
}
if task.ID != "test-id" {
t.Errorf("Expected task ID 'test-id', got '%s'", task.ID)
}
if len(task.URLs) != 1 || task.URLs[0] != "https://example.com/video" {
t.Errorf("Expected URLs to contain 'https://example.com/video', got %v", task.URLs)
}
if len(task.Flags) != 2 || task.Flags[0] != "--format" || task.Flags[1] != "best" {
t.Errorf("Expected flags to contain '--format' and 'best', got %v", task.Flags)
}
if task.Storage.Name() != "test-storage" {
t.Errorf("Expected storage name 'test-storage', got '%s'", task.Storage.Name())
}
}
func TestNewTaskWithoutFlags(t *testing.T) {
ctx := context.Background()
urls := []string{"https://example.com/video1", "https://example.com/video2"}
var flags []string // No flags
stor := &MockStorage{}
storPath := "test-path"
task := NewTask("test-id-2", ctx, urls, flags, stor, storPath, nil)
if task == nil {
t.Fatal("NewTask returned nil")
}
if len(task.URLs) != 2 {
t.Errorf("Expected 2 URLs, got %d", len(task.URLs))
}
if len(task.Flags) != 0 {
t.Errorf("Expected 0 flags, got %d", len(task.Flags))
}
}
func TestTaskTitle(t *testing.T) {
ctx := context.Background()
stor := &MockStorage{}
// Test with single URL
task1 := NewTask("id1", ctx, []string{"https://example.com/video"}, nil, stor, "path", nil)
title1 := task1.Title()
if title1 == "" {
t.Error("Task title should not be empty")
}
// Test with multiple URLs
task2 := NewTask("id2", ctx, []string{"https://example.com/v1", "https://example.com/v2"}, nil, stor, "path", nil)
title2 := task2.Title()
if title2 == "" {
t.Error("Task title should not be empty")
}
}
func TestTaskType(t *testing.T) {
ctx := context.Background()
stor := &MockStorage{}
task := NewTask("id", ctx, []string{"https://example.com"}, nil, stor, "path", nil)
taskType := task.Type()
if taskType.String() != "ytdlp" {
t.Errorf("Expected task type 'ytdlp', got '%s'", taskType.String())
}
}
func TestTaskID(t *testing.T) {
ctx := context.Background()
stor := &MockStorage{}
expectedID := "test-task-id-123"
task := NewTask(expectedID, ctx, []string{"https://example.com"}, nil, stor, "path", nil)
if task.TaskID() != expectedID {
t.Errorf("Expected task ID '%s', got '%s'", expectedID, task.TaskID())
}
}

View File

@@ -49,4 +49,4 @@ func GetUserByID(ctx context.Context, id uint) (*User, error) {
Preload(clause.Associations).
Where("id = ?", id).First(&user).Error
return &user, err
}
}

View File

@@ -0,0 +1,257 @@
---
title: "HTTP API"
weight: 4
---
# HTTP API
SaveAny-Bot provides a RESTful HTTP API for programmatic file downloads from Telegram.
## Configuration
Enable the API in your `config.toml`:
```toml
[api]
# Enable HTTP API service
enable = true
# API server listen port
port = 8080
# API access token (leave empty to disable authentication)
token = "your-secret-token-here"
# Task completion callback webhook URL (leave empty to disable)
webhook_url = "https://your-server.com/webhook"
```
## Authentication
If `token` is configured, all API requests (except `/health`) must include an `Authorization` header:
```
Authorization: Bearer your-secret-token-here
```
## Endpoints
### Health Check
Check if the API server is running.
**Request:**
```
GET /health
```
**Response:**
```json
{
"status": "ok"
}
```
### Create Download Task
Create a new file download task from a Telegram message link.
**Request:**
```
POST /api/v1/tasks
Content-Type: application/json
Authorization: Bearer your-secret-token-here
{
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}
```
**Request Parameters:**
- `telegram_url` (required): Telegram message link (e.g., `https://t.me/channel/123`)
- `user_id` (required): Telegram user ID (must be configured in `config.toml`)
- `storage_name` (optional): Storage name to use. If not specified, uses the first available storage for the user
- `dir_path` (optional): Directory path in storage. Default is `/`
**Response (201 Created):**
```json
{
"task_id": "c9h8t1234abcd",
"message": "task created successfully"
}
```
**Error Response (4xx/5xx):**
```json
{
"error": "error description"
}
```
### Get Task Status
Get the status of a specific task.
**Request:**
```
GET /api/v1/tasks/{task_id}
Authorization: Bearer your-secret-token-here
```
**Response (200 OK):**
```json
{
"task_id": "c9h8t1234abcd",
"status": "completed",
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
"created_at": "2024-01-19T04:30:00Z",
"error": ""
}
```
**Status Values:**
- `queued`: Task is waiting in queue
- `running`: Task is currently downloading
- `completed`: Task completed successfully
- `failed`: Task failed with error (see `error` field)
- `canceled`: Task was canceled
### List All Tasks
List all queued and running tasks.
**Request:**
```
GET /api/v1/tasks
Authorization: Bearer your-secret-token-here
```
**Response (200 OK):**
```json
{
"queued": [
{
"id": "c9h8t1234abcd",
"title": "[tgfiles](file1.pdf->local1:/downloads/file1.pdf)"
}
],
"running": [
{
"id": "d2k9u5678efgh",
"title": "[tgfiles](file2.pdf->local1:/downloads/file2.pdf)"
}
]
}
```
### Cancel Task
Cancel a running or queued task.
**Request:**
```
DELETE /api/v1/tasks/{task_id}
Authorization: Bearer your-secret-token-here
```
**Response (200 OK):**
```json
{
"message": "task canceled"
}
```
## Webhook Callback
If `webhook_url` is configured, the API will send a POST request to the webhook URL when a task completes or fails.
**Webhook Request:**
```
POST {webhook_url}
Content-Type: application/json
Authorization: Bearer your-secret-token-here
{
"task_id": "c9h8t1234abcd",
"status": "completed",
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
"created_at": "2024-01-19T04:30:00Z",
"error": ""
}
```
## Example Usage
### Using cURL
**Create a download task:**
```bash
curl -X POST http://localhost:8080/api/v1/tasks \
-H "Content-Type: application/json" \
-H "Authorization: Bearer your-secret-token-here" \
-d '{
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}'
```
**Get task status:**
```bash
curl http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
-H "Authorization: Bearer your-secret-token-here"
```
**List all tasks:**
```bash
curl http://localhost:8080/api/v1/tasks \
-H "Authorization: Bearer your-secret-token-here"
```
**Cancel a task:**
```bash
curl -X DELETE http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
-H "Authorization: Bearer your-secret-token-here"
```
### Using Python
```python
import requests
API_URL = "http://localhost:8080"
TOKEN = "your-secret-token-here"
HEADERS = {
"Authorization": f"Bearer {TOKEN}",
"Content-Type": "application/json"
}
# Create a download task
response = requests.post(
f"{API_URL}/api/v1/tasks",
headers=HEADERS,
json={
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}
)
task_id = response.json()["task_id"]
# Get task status
response = requests.get(
f"{API_URL}/api/v1/tasks/{task_id}",
headers=HEADERS
)
status = response.json()
print(f"Task status: {status['status']}")
```
## Security Recommendations
1. **Always use a strong token** for production environments
2. **Use HTTPS** in production by placing the API behind a reverse proxy (e.g., Nginx, Caddy)
3. **Keep logs secure** as they may contain sensitive information
4. **Validate user permissions** - ensure `user_id` in requests corresponds to authorized users in your config

View File

@@ -0,0 +1,257 @@
---
title: "HTTP API"
weight: 4
---
# HTTP API
SaveAny-Bot 提供 RESTful HTTP API支持通过编程方式从 Telegram 下载文件。
## 配置
`config.toml` 中启用 API
```toml
[api]
# 启用 HTTP API 服务
enable = true
# API 服务监听端口
port = 8080
# API 访问令牌 (留空则不验证)
token = "your-secret-token-here"
# 任务完成回调 Webhook URL (留空则不回调)
webhook_url = "https://your-server.com/webhook"
```
## 认证
如果配置了 `token`,所有 API 请求(除了 `/health`)都必须包含 `Authorization` 头:
```
Authorization: Bearer your-secret-token-here
```
## 端点
### 健康检查
检查 API 服务器是否正在运行。
**请求:**
```
GET /health
```
**响应:**
```json
{
"status": "ok"
}
```
### 创建下载任务
从 Telegram 消息链接创建新的文件下载任务。
**请求:**
```
POST /api/v1/tasks
Content-Type: application/json
Authorization: Bearer your-secret-token-here
{
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}
```
**请求参数:**
- `telegram_url` (必填): Telegram 消息链接 (例如: `https://t.me/channel/123`)
- `user_id` (必填): Telegram 用户 ID (必须在 `config.toml` 中配置)
- `storage_name` (可选): 要使用的存储名称。如果未指定,使用用户的第一个可用存储
- `dir_path` (可选): 存储中的目录路径。默认为 `/`
**响应 (201 Created)**
```json
{
"task_id": "c9h8t1234abcd",
"message": "task created successfully"
}
```
**错误响应 (4xx/5xx)**
```json
{
"error": "错误描述"
}
```
### 获取任务状态
获取特定任务的状态。
**请求:**
```
GET /api/v1/tasks/{task_id}
Authorization: Bearer your-secret-token-here
```
**响应 (200 OK)**
```json
{
"task_id": "c9h8t1234abcd",
"status": "completed",
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
"created_at": "2024-01-19T04:30:00Z",
"error": ""
}
```
**状态值:**
- `queued`: 任务正在队列中等待
- `running`: 任务正在下载
- `completed`: 任务成功完成
- `failed`: 任务失败(查看 `error` 字段)
- `canceled`: 任务已取消
### 列出所有任务
列出所有排队和正在运行的任务。
**请求:**
```
GET /api/v1/tasks
Authorization: Bearer your-secret-token-here
```
**响应 (200 OK)**
```json
{
"queued": [
{
"id": "c9h8t1234abcd",
"title": "[tgfiles](file1.pdf->local1:/downloads/file1.pdf)"
}
],
"running": [
{
"id": "d2k9u5678efgh",
"title": "[tgfiles](file2.pdf->local1:/downloads/file2.pdf)"
}
]
}
```
### 取消任务
取消正在运行或排队的任务。
**请求:**
```
DELETE /api/v1/tasks/{task_id}
Authorization: Bearer your-secret-token-here
```
**响应 (200 OK)**
```json
{
"message": "task canceled"
}
```
## Webhook 回调
如果配置了 `webhook_url`API 会在任务完成或失败时向 webhook URL 发送 POST 请求。
**Webhook 请求:**
```
POST {webhook_url}
Content-Type: application/json
Authorization: Bearer your-secret-token-here
{
"task_id": "c9h8t1234abcd",
"status": "completed",
"title": "[tgfiles](file.pdf->local1:/downloads/file.pdf)",
"created_at": "2024-01-19T04:30:00Z",
"error": ""
}
```
## 使用示例
### 使用 cURL
**创建下载任务:**
```bash
curl -X POST http://localhost:8080/api/v1/tasks \
-H "Content-Type: application/json" \
-H "Authorization: Bearer your-secret-token-here" \
-d '{
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}'
```
**获取任务状态:**
```bash
curl http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
-H "Authorization: Bearer your-secret-token-here"
```
**列出所有任务:**
```bash
curl http://localhost:8080/api/v1/tasks \
-H "Authorization: Bearer your-secret-token-here"
```
**取消任务:**
```bash
curl -X DELETE http://localhost:8080/api/v1/tasks/c9h8t1234abcd \
-H "Authorization: Bearer your-secret-token-here"
```
### 使用 Python
```python
import requests
API_URL = "http://localhost:8080"
TOKEN = "your-secret-token-here"
HEADERS = {
"Authorization": f"Bearer {TOKEN}",
"Content-Type": "application/json"
}
# 创建下载任务
response = requests.post(
f"{API_URL}/api/v1/tasks",
headers=HEADERS,
json={
"telegram_url": "https://t.me/channel/123",
"user_id": 123456789,
"storage_name": "local1",
"dir_path": "/downloads"
}
)
task_id = response.json()["task_id"]
# 获取任务状态
response = requests.get(
f"{API_URL}/api/v1/tasks/{task_id}",
headers=HEADERS
)
status = response.json()
print(f"任务状态: {status['status']}")
```
## 安全建议
1. **生产环境始终使用强令牌**
2. **生产环境使用 HTTPS**,通过反向代理(如 Nginx、Caddy放置 API
3. **保护日志安全**,因为它们可能包含敏感信息
4. **验证用户权限** - 确保请求中的 `user_id` 对应于配置中的授权用户

View File

@@ -1,6 +1,5 @@
package ctxkey
// ENUM(content-length)
//
//go:generate go-enum --values --names --flag --nocase --noprefix
// ENUM(content-length)
type ContextKey string

View File

@@ -1,6 +1,5 @@
package tasktype
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp)
//
//go:generate go-enum --values --names --flag --nocase
// ENUM(tgfiles,tphpics,parseditem,directlinks,aria2,ytdlp)
type TaskType string

View File

@@ -48,8 +48,7 @@ type Add struct {
// aria2
Aria2URIs []string
// ytdlp
YtdlpURLs []string
YtdlpFlags []string
YtdlpURLs []string
}
type SetDefaultStorage struct {

View File

@@ -36,4 +36,4 @@ func WithSizeIfZero(size int64) TGFileOption {
f.size = size
}
}
}
}